feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*- #
# Copyright 2015 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for interacting with the Compute API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.compute import base_classes as compute_base
from googlecloudsdk.api_lib.compute import constants as compute_constants
from googlecloudsdk.api_lib.compute import utils as compute_utils
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.compute import flags
from googlecloudsdk.command_lib.compute import scope as compute_scope
from googlecloudsdk.command_lib.compute import scope_prompter
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
# Re-exported from the compute constants so dataproc keeps a cleaner
# separation from the compute API surface.
SCOPE_ALIASES = compute_constants.SCOPES
# Help text describing the known scope aliases (see
# compute_constants.ScopesHelp).
SCOPES_HELP = compute_constants.ScopesHelp()
def ExpandScopeAliases(scopes):
  """Replace known aliases in the list of scopes provided by the user.

  Args:
    scopes: Iterable of user-provided scope strings, or None.

  Returns:
    A sorted list of scope URIs, with every known alias replaced by the
    URIs it stands for. Unknown scopes are passed through unchanged and
    validated server side.
  """
  expanded = []
  for user_scope in scopes or []:
    # Unrecognized entries fall back to themselves; the server validates them.
    expanded.extend(SCOPE_ALIASES.get(user_scope, [user_scope]))
  return sorted(expanded)
def GetComputeResources(release_track, cluster_name, dataproc_region):
  """Returns a resources object with resolved GCE zone and region.

  Side effect: sets the compute/zone and compute/region properties so later
  GCE resource parsing picks up the resolved values.

  Args:
    release_track: Calliope release track used to build the Compute holder.
    cluster_name: Cluster name, shown when prompting the user for a zone.
    dataproc_region: The Dataproc region handling the request.

  Returns:
    The resource registry from the Compute API holder.
  """
  holder = compute_base.ComputeApiHolder(release_track)
  region_prop = properties.VALUES.compute.region
  zone_prop = properties.VALUES.compute.zone
  resources = holder.resources
  # Prompt for scope if necessary. If Dataproc regional stack is used, omitting
  # the zone allows the server to pick a zone
  zone = properties.VALUES.compute.zone.Get()
  if not zone and dataproc_region == 'global':
    _, zone = scope_prompter.PromptForScope(
        resource_name='cluster',
        underspecified_names=[cluster_name],
        scopes=[compute_scope.ScopeEnum.ZONE],
        default_scope=None,
        scope_lister=flags.GetDefaultScopeLister(holder.client))
    if not zone:
      # Still no zone, just raise error generated by this property.
      zone = properties.VALUES.compute.zone.GetOrFail()
  if zone:
    zone_ref = resources.Parse(
        zone,
        params={
            'project': properties.VALUES.core.project.GetOrFail,
        },
        collection='compute.zones')
    zone_name = zone_ref.Name()
    zone_prop.Set(zone_name)
    # Derive the GCE region from the zone name (e.g. us-central1-a ->
    # us-central1) so both properties stay consistent.
    region_name = compute_utils.ZoneNameToRegionName(zone_name)
    region_prop.Set(region_name)
  else:
    # Auto zone
    zone_prop.Set('')
    # Set GCE region to dataproc region (which is a 1:1 mapping)
    region_prop.Set(dataproc_region)
  return resources
def GetDefaultServiceAccount(project_id):
  """Call Compute.Projects.Get to find project_id's default Service Account.

  Args:
    project_id: The project whose default compute Service Account is wanted.

  Returns:
    The project's default compute Service Account.
  """
  holder = compute_base.ComputeApiHolder(base.ReleaseTrack.GA)
  client = holder.client
  # Future optimization: limit the response size with
  # fields='defaultServiceAccount'.
  get_request = client.messages.ComputeProjectsGetRequest(project=project_id)
  project = client.apitools_client.projects.Get(request=get_request)
  service_account = project.defaultServiceAccount
  log.debug('Default compute Service Account is %s.', service_account)
  return service_account

View File

@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*- #
# Copyright 2015 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants for the dataproc tool."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
# TODO(b/36055865): Move defaults to the server

# Path inside of GCS bucket, where Dataproc stores metadata.
GCS_METADATA_PREFIX = 'google-cloud-dataproc-metainfo'

# Beginning of driver output files.
JOB_OUTPUT_PREFIX = 'driveroutput'

# The scopes that will be added to user-specified scopes. Used for
# documentation only. Keep in sync with server specified list.
MINIMUM_SCOPE_URIS = [
    'https://www.googleapis.com/auth/devstorage.read_write',
    'https://www.googleapis.com/auth/logging.write',
]

# The scopes that will be specified by default. Used for documentation only.
# Keep in sync with server specified list.
ADDITIONAL_DEFAULT_SCOPE_URIS = [
    'https://www.googleapis.com/auth/bigquery',
    'https://www.googleapis.com/auth/bigtable.admin.table',
    'https://www.googleapis.com/auth/bigtable.data',
    'https://www.googleapis.com/auth/devstorage.full_control',
]

# The default page size for list pagination.
DEFAULT_PAGE_SIZE = 100

# Dataproc cluster property keys (the key strings describe the feature each
# one controls; interpreted server side).
ALLOW_ZERO_WORKERS_PROPERTY = 'dataproc:dataproc.allow.zero.workers'
ENABLE_NODE_GROUPS_PROPERTY = 'dataproc:dataproc.nodegroups.enabled'
ENABLE_DYNAMIC_MULTI_TENANCY_PROPERTY = (
    'dataproc:dataproc.dynamic.multi.tenancy.enabled'
)

View File

@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*- #
# Copyright 2015 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common stateful utilities for the gcloud dataproc tool."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.calliope import base
from googlecloudsdk.core import resources
class Dataproc(object):
  """Stateful utility for calling Dataproc APIs.

  While this currently could all be static, it is encapsulated in a class to
  support API version switching in the future.
  """

  def __init__(self, release_track=base.ReleaseTrack.GA):
    super(Dataproc, self).__init__()
    self.release_track = release_track
    self.api_version = 'v1'
    # Client and resource registry are constructed lazily on first access.
    self._client = None
    self._resources = None

  @property
  def client(self):
    """The Dataproc API client, created on first access."""
    if self._client is None:
      self._client = apis.GetClientInstance('dataproc', self.api_version)
    return self._client

  @property
  def messages(self):
    """The Dataproc API messages module."""
    return apis.GetMessagesModule('dataproc', self.api_version)

  @property
  def resources(self):
    """A resource registry with the Dataproc API registered."""
    if self._resources is None:
      registry = resources.REGISTRY.Clone()
      registry.RegisterApiByName('dataproc', self.api_version)
      self._resources = registry
    return self._resources

  @property
  def terminal_job_states(self):
    """Job states after which a job no longer changes."""
    state_enum = self.messages.JobStatus.StateValueValuesEnum
    return [state_enum.CANCELLED, state_enum.DONE, state_enum.ERROR]

  def GetCreateClusterRequest(self,
                              cluster,
                              project_id,
                              region,
                              request_id,
                              action_on_failed_primary_workers=None):
    """Gets the CreateClusterRequest for the appropriate api version.

    Args:
      cluster: Dataproc cluster to be created.
      project_id: The ID of the Google Cloud Platform project that the cluster
        belongs to.
      region: The Dataproc region in which to handle the request.
      request_id: A unique ID used to identify the request.
      action_on_failed_primary_workers: Supported only for v1 api.

    Returns:
      DataprocProjectsRegionsClustersCreateRequest

    Raises:
      ValueError: if non-None action_on_failed_primary_workers is passed for
        v1beta2 api.
    """
    request_kwargs = {
        'cluster': cluster,
        'projectId': project_id,
        'region': region,
        'requestId': request_id,
    }
    if action_on_failed_primary_workers is not None:
      if self.api_version == 'v1beta2':
        raise ValueError(
            'action_on_failed_primary_workers is not supported for v1beta2 api')
      request_kwargs['actionOnFailedPrimaryWorkers'] = (
          action_on_failed_primary_workers)
    return self.messages.DataprocProjectsRegionsClustersCreateRequest(
        **request_kwargs)

  def GetRegionsWorkflowTemplate(self, template, version=None):
    """Gets workflow template from dataproc.

    Args:
      template: workflow template resource that contains template name and id.
      version: version of the workflow template to get.

    Returns:
      WorkflowTemplate object that contains the workflow template info.

    Raises:
      ValueError: if version cannot be converted to a valid integer.
    """
    get_request = (
        self.messages.DataprocProjectsRegionsWorkflowTemplatesGetRequest(
            name=template.RelativeName()))
    if version:
      get_request.version = int(version)
    return self.client.projects_regions_workflowTemplates.Get(get_request)

View File

@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Refine server response for display."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import encoding
from googlecloudsdk.api_lib.dataproc import util
class DisplayHelper(util.Bunch):
  """Refine server response for display."""

  def __init__(self, job):
    super(DisplayHelper, self).__init__(encoding.MessageToDict(job))
    self._job = job

  @property
  def jobType(self):
    return self.getTruncatedFieldNameBySuffix('Job')

  @property
  def batchType(self):
    return self.getTruncatedFieldNameBySuffix('Batch')

  @property
  def sessionType(self):
    return self.getTruncatedFieldNameBySuffix('Session')

  def getTruncatedFieldNameBySuffix(self, suffix):
    """Get a field name by suffix and truncate it.

    The one_of fields in the server response use their type name as the field
    key (e.g. "sparkJob"), so the assigned field whose name ends with the
    given suffix identifies the response type; the suffix is stripped before
    returning (e.g. 'sparkJob' with suffix 'Job' yields 'spark').

    Args:
      suffix: the suffix to match.

    Returns:
      The first matched truncated field name.

    Raises:
      AttributeError: when no assigned field matches the suffix.
    """
    for descriptor in self._job.all_fields():
      name = descriptor.name
      if name.endswith(suffix) and self._job.get_assigned_value(name):
        return name.rpartition(suffix)[0]
    raise AttributeError('Response has no field with {} as suffix.'
                         .format(suffix))

View File

@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper for user-visible error exceptions to raise in the CLI."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.core import exceptions
class Error(exceptions.Error):
  # Original docstring said "Deployment Manager" — a copy/paste from another
  # surface; this module is the dataproc error hierarchy root.
  """Base class for dataproc tool errors."""
class ArgumentError(Error):
  """Raised when a command argument is invalid."""
class JobError(Error):
  """Raised when a Dataproc job encounters an error."""
class JobTimeoutError(JobError):
  """Raised when waiting on a Dataproc job times out."""
class OperationError(Error):
  """Raised when a Dataproc operation encounters an error."""
class OperationTimeoutError(OperationError):
  """Raised when waiting on a Dataproc operation times out."""
class ParseError(Error):
  """Raised when a file cannot be parsed."""
class FileUploadError(Error):
  """Raised when uploading a file fails."""
class ObjectReadError(Error):
  """Raised when reading a Cloud Storage object fails."""
class ValidationError(Error):
  """Raised when YAML fails validation against its schema."""
# NOTE(review): unlike every other error in this module, this derives from the
# builtin Exception rather than the local Error base, so `except Error`
# handlers will NOT catch it — confirm whether that is intentional before
# unifying the hierarchy.
class PersonalAuthError(Exception):
  """Error while establishing a personal auth session."""
class GkeClusterGetError(Error):
  """Error while getting a GKE Cluster."""

  def __init__(self, cause):
    message = 'Error while getting the GKE Cluster: {0}'.format(cause)
    super(GkeClusterGetError, self).__init__(message)
class GkeClusterMissingWorkloadIdentityError(Error):
  """GKE Cluster is not Workload Identity enabled."""

  def __init__(self, gke_cluster_ref):
    super(GkeClusterMissingWorkloadIdentityError, self).__init__()
    self.gke_cluster_ref = gke_cluster_ref

  def __str__(self):
    template = (
        'GKE Cluster "{0}" does not have Workload Identity enabled. Dataproc '
        'on GKE requires the GKE Cluster to have Workload Identity enabled. '
        'See '
        'https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity'
    )
    return template.format(self.gke_cluster_ref.RelativeName())

View File

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for interacting with the GKE API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.container import api_adapter as gke_api_adapter
from googlecloudsdk.api_lib.dataproc import exceptions
from googlecloudsdk.core import log
def GetGkeClusterIsWorkloadIdentityEnabled(project, location, cluster):
  """Determines if the GKE cluster is Workload Identity enabled.

  Args:
    project: Project that owns the GKE cluster.
    location: Location of the GKE cluster.
    cluster: Name of the GKE cluster.

  Returns:
    True iff the cluster has a workloadIdentityConfig with a non-empty
    workloadPool.
  """
  config = _GetGkeCluster(project, location, cluster).workloadIdentityConfig
  if not config:
    log.debug('GKE cluster does not have a workloadIdentityConfig.')
    return False
  if not config.workloadPool:
    log.debug('GKE cluster\'s workloadPool is the empty string.')
    return False
  return True
def _GetGkeCluster(project, location, cluster):
  """Gets the GKE cluster.

  Args:
    project: Project that owns the GKE cluster.
    location: Location of the GKE cluster.
    cluster: Name of the GKE cluster.

  Returns:
    The cluster fetched via the GKE v1 API adapter.

  Raises:
    exceptions.GkeClusterGetError: if parsing the cluster reference or
      fetching the cluster fails for any reason.
  """
  gke_client = gke_api_adapter.NewV1APIAdapter()
  try:
    cluster_ref = gke_client.ParseCluster(
        name=cluster, location=location, project=project)
    return gke_client.GetCluster(cluster_ref)
  except Exception as e:  # pylint: disable=broad-except
    # Chain the original exception so the root cause stays visible in the
    # traceback instead of being flattened into the wrapper's message.
    raise exceptions.GkeClusterGetError(e) from e

View File

@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for interacting with the IAM API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.iam import util as iam_api
from googlecloudsdk.command_lib.iam import iam_util
from googlecloudsdk.core import log
# This is the maximum version of the IAM policy that Dataproc supports.
MAX_LIBRARY_IAM_SUPPORTED_VERSION = 3
def AddIamPolicyBindings(resource, members, role):
  """Adds IAM policy bindings for members with the role on resource.

  Reads the service account's current IAM policy, adds a binding for each
  member not already bound to the role, and writes the policy back only if
  something actually changed.

  Args:
    resource: The service account resource to update.
    members: Iterable of member strings to bind to the role.
    role: The role to grant to the members.
  """
  iam_client, iam_messages = iam_api.GetClientAndMessages()
  get_request = iam_messages.IamProjectsServiceAccountsGetIamPolicyRequest(
      resource=resource)
  policy = iam_client.projects_serviceAccounts.GetIamPolicy(request=get_request)
  needs_update = False
  for member in members:
    if iam_util.AddBindingToIamPolicy(iam_messages.Binding, policy, member,
                                      role):
      needs_update = True
  if not needs_update:
    log.debug('Skipped setting IAM policy, no changes are needed.')
    return
  log.debug('Setting the updated IAM policy.')
  set_request = iam_messages.IamProjectsServiceAccountsSetIamPolicyRequest(
      resource=resource,
      setIamPolicyRequest=iam_messages.SetIamPolicyRequest(policy=policy))
  iam_client.projects_serviceAccounts.SetIamPolicy(request=set_request)

View File

@@ -0,0 +1,199 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract waiter utility for api_lib.util.waiter.py.
Abstract waiter utility class for api_lib.util.waiter.WaitFor.
This class is the base class for poller that need to stream output and poll on
dataproc's operation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import abc
from googlecloudsdk.api_lib.dataproc import storage_helpers
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import log
from googlecloudsdk.core.console import console_attr
# TODO(b/191187715): Migrate current job poller to extend this class.
class AbstractOperationStreamerPoller(waiter.OperationPoller):
  """Base abstract poller class for dataproc operation.

  Base abstract poller class for dataproc operation. The class is designed to
  stream remote output from GCS.

  Pass TrackerUpdateFunction to waiter.WaitFor's tracker_update_func parameter
  to stream remote output.

  NOTE(review): enforcement of @abc.abstractmethod here depends on
  waiter.OperationPoller supplying an ABC metaclass — confirm.
  """

  def __init__(self, dataproc):
    """Poller for batch workload.

    Args:
      dataproc: A api_lib.dataproc.Dataproc instance.
    """
    # GCS URI currently being streamed; None until the first URI is seen.
    self.saved_stream_uri = None
    # Active output streamer (storage_helpers.StorageObjectSeriesStream),
    # or None when no stream has been established yet.
    self.driver_log_streamer = None
    self.dataproc = dataproc

  @abc.abstractmethod
  def IsDone(self, poll_result):
    """Determines if the poll result is done.

    Determines if the poll result is done. This is a null implementation that
    simply returns False. Sub class should override this and provide concrete
    checking logic.

    Overrides.

    Args:
      poll_result: Poll result returned from Poll function.

    Returns:
      True if the remote resource is done, or False otherwise.
    """
    return False

  @abc.abstractmethod
  def Poll(self, ref):
    """Fetches remote resource.

    Fetches remote resource. This is a null implementation that simply returns
    None. Sub class should override this and provide concrete fetching logic.

    Overrides.

    Args:
      ref: Resource reference. The same argument passed to waiter.WaitFor.

    Returns:
      None. Sub class should override this and return the actual fetched
      resource.
    """
    return None

  @abc.abstractmethod
  def _GetOutputUri(self, poll_result):
    """Gets output uri from poll result.

    Gets output uri from poll result. This is a null implementation that
    returns None. Sub class should override this and return actual output uri
    for output streamer, or returns None if something goes wrong and there is
    no output uri in the poll result.

    Args:
      poll_result: Poll result returned by Poll.

    Returns:
      None. Sub class should override this and returns actual output uri, or
      None when something goes wrong.
    """
    return None

  @abc.abstractmethod
  def _GetResult(self, poll_result):
    """Returns operation result to caller.

    This function is called after GetResult streams remote output.
    This is a null implementation that simply returns None. Sub class should
    override this and provide actual _GetResult logic.

    Args:
      poll_result: Poll result returned from Poll.

    Returns:
      None. Sub class should override this and return actual result.
    """
    return None

  def GetResult(self, poll_result):
    """Returns result for remote resource.

    This function first streams remote output to the user, then returns the
    operation result by calling _GetResult.

    Overrides.

    Args:
      poll_result: Poll result returned by Poll.

    Returns:
      Whatever is returned from _GetResult.
    """
    # Stream the remaining outputs.
    # There won't be new remote output, so no need to poll on it.
    # Let the streamer stream until it ends.
    self.TrackerUpdateFunction(None, poll_result, None)
    return self._GetResult(poll_result)

  def TrackerUpdateFunction(self, tracker, poll_result, status):
    """Custom tracker function which gets called after every tick.

    This gets called whenever progress tracker gets a tick. However we want to
    stream remote output to users instead of showing a progress tracker.

    Args:
      tracker: Progress tracker instance. Not being used.
      poll_result: Result from Poll function.
      status: Status argument that is supposed to pass to the progress tracker
        instance. Not being used here as well.
    """
    self._CheckStreamer(poll_result)
    self._StreamOutput()

  def _StreamOutput(self):
    # Copy any newly available remote output to stderr.
    if self.driver_log_streamer and self.driver_log_streamer.open:
      self.driver_log_streamer.ReadIntoWritable(log.err)

  def _CheckStreamer(self, poll_result):
    """Checks if need to init a new output streamer.

    Checks if need to init a new output streamer.
    Remote may fail; switch to new output uri.
    Invalidate the streamer instance and init a new one if necessary.

    Args:
      poll_result: Poll result returned from Poll.
    """
    # Mimic current job waiting behavior to print equal signs across the screen.
    def _PrintEqualsLineAccrossScreen():
      attr = console_attr.GetConsoleAttr()
      log.err.Print('=' * attr.GetTermSize()[0])

    # _GetOutputUri's base implementation returns None; subclasses override it.
    # pylint: disable=assignment-from-none
    uri = self._GetOutputUri(poll_result)
    # pylint: enable=assignment-from-none
    if not uri:
      # Remote resource not ready, nothing to check.
      return
    # Invalidate current streamer if remote output uri changed.
    if self.saved_stream_uri and self.saved_stream_uri != uri:
      self.driver_log_streamer = None
      self.saved_stream_uri = None
      _PrintEqualsLineAccrossScreen()
      log.warning("Attempt failed. Streaming new attempt's output.")
      _PrintEqualsLineAccrossScreen()
    # Init a new streamer if there is no active streamer.
    if not self.driver_log_streamer:
      self.saved_stream_uri = uri
      self.driver_log_streamer = storage_helpers.StorageObjectSeriesStream(uri)

View File

@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper class for generating Cloud Logging URLs for Dataproc resources."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import re
from googlecloudsdk.core import properties
from six.moves.urllib import parse
def get_plain_batch_logging_url():
  """Returns the base URL for the Cloud Logging console.

  This is used when parsing batch resource failed.
  """
  query = parse.urlencode({'query': 'resource.type="cloud_dataproc_batch"'})
  return 'https://console.cloud.google.com/logs/query?' + query
def get_batch_logging_url(batch):
  """Returns a Cloud Logging URL for the given batch.

  Args:
    batch: The batch to get the Cloud Logging URL for.

  Returns:
    A Cloud Logging URL for the given batch or a plain url without batch info.
  """
  match = re.match(
      r'projects/(?P<project_id>[^/]+)/locations/[^/]+/batches/(?P<batch_id>[^/]+)',
      batch.name,
  )
  if match is None:
    # Could not extract project/batch ids; link to an unfiltered query.
    return get_plain_batch_logging_url()
  project_id = match.group('project_id')
  batch_id = match.group('batch_id')
  universe_domain = properties.VALUES.core.universe_domain.Get()
  query_filters = '\n'.join([
      'resource.type="cloud_dataproc_batch"',
      'resource.labels.batch_id="{}"'.format(batch_id),
      'log_name="projects/{}/logs/dataproc.{}%2Foutput"'.format(
          project_id, universe_domain),
  ])
  return '{}?{}&{}'.format(
      'https://console.cloud.google.com/logs/query',
      parse.urlencode({'query': query_filters}),
      'project={}'.format(project_id),
  )
def get_plain_batches_list_url():
  """Returns the base URL for the Dataproc Batches console.

  This is used when parsing batch resource failed.
  """
  return 'https://console.cloud.google.com/dataproc/batches'
def get_dataproc_batch_url(batch):
  """Returns a Dataproc Batch URL for the given batch."""
  match = re.match(
      r'projects/(?P<project_id>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)',
      batch.name,
  )
  if not match:
    # Batch name did not parse; fall back to the generic batches list page.
    return get_plain_batches_list_url()
  return (
      'https://console.cloud.google.com/dataproc/batches/'
      '{location}/{batch_id}/summary?project={project_id}'.format(
          location=match.group('location'),
          batch_id=match.group('batch_id'),
          project_id=match.group('project_id'),
      )
  )

View File

@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Waiter utility for api_lib.util.waiter.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import exceptions as apitools_exceptions
from googlecloudsdk.api_lib.dataproc import exceptions
from googlecloudsdk.api_lib.dataproc import util
from googlecloudsdk.api_lib.dataproc.poller import (
abstract_operation_streamer_poller as dataproc_poller_base,
)
from googlecloudsdk.core import log
class GceBatchPoller(
    dataproc_poller_base.AbstractOperationStreamerPoller,
):
  """Poller for GCE based batch workload."""

  def IsDone(self, batch):
    """See base class."""
    if not batch:
      return False
    state_enum = self.dataproc.messages.Batch.StateValueValuesEnum
    return batch.state in (
        state_enum.SUCCEEDED,
        state_enum.CANCELLED,
        state_enum.FAILED,
    )

  def Poll(self, batch_ref):
    """See base class."""
    messages = self.dataproc.messages
    request = messages.DataprocProjectsLocationsBatchesGetRequest(
        name=batch_ref
    )
    try:
      return self.dataproc.client.projects_locations_batches.Get(request)
    except apitools_exceptions.HttpError as error:
      log.warning('Get Batch failed:\n{}'.format(error))
      # Client (4xx) errors will not resolve by themselves; stop polling.
      # Other HTTP errors return None implicitly so the waiter keeps trying.
      if util.IsClientHttpException(error):
        raise

  def _GetResult(self, batch):
    """Handles errors.

    Error handling for batch jobs. This happens after the batch reaches one
    of the complete states.

    Overrides.

    Args:
      batch: The batch resource.

    Returns:
      None. The result is directly output to log.err.

    Raises:
      JobTimeoutError: When waiter timed out.
      JobError: When remote batch job is failed.
    """
    if not batch:
      # A missing batch with polling considered done means the waiter
      # timed out before the resource was ever fetched.
      raise exceptions.JobTimeoutError('Timed out while waiting for batch job.')
    state_enum = self.dataproc.messages.Batch.StateValueValuesEnum
    if batch.state == state_enum.SUCCEEDED:
      if not self.driver_log_streamer:
        log.warning('Expected batch job output not found.')
      elif self.driver_log_streamer.open:
        # The stream is still open: remote output did not end correctly.
        log.warning(
            'Batch job terminated, but output did not finish streaming.'
        )
      return None
    if batch.state == state_enum.CANCELLED:
      log.warning('Batch job is CANCELLED.')
      return None
    # FAILED: build an error message, appending server detail when present.
    err_message = 'Batch job is FAILED.'
    if batch.stateMessage:
      err_message = '{} Detail: {}'.format(err_message, batch.stateMessage)
    if err_message[-1] != '.':
      err_message += '.'
    err_message += (
        '\n'
        'Running auto diagnostics on the batch. It may take few '
        'minutes before diagnostics output is available. Please '
        "check diagnostics output by running 'gcloud dataproc "
        "batches describe' command."
    )
    raise exceptions.JobError(err_message)

  def _GetOutputUri(self, batch):
    """See base class."""
    runtime_info = batch.runtimeInfo if batch else None
    if runtime_info and runtime_info.outputUri:
      return runtime_info.outputUri
    return None

View File

@@ -0,0 +1,132 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Waiter utility for api_lib.util.waiter.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import logging
from apitools.base.py import exceptions as apitools_exceptions
from googlecloudsdk.api_lib.dataproc import exceptions
from googlecloudsdk.api_lib.dataproc import util
from googlecloudsdk.api_lib.dataproc.poller import cloud_console_url_helper
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import log
class RmBatchPoller(waiter.OperationPoller):
"""Poller for resource manager batches.
This should be used for spark version 3+, and Ray version 1+.
"""
  def __init__(self, dataproc):
    """Initializes the poller.

    Args:
      dataproc: A api_lib.dataproc.Dataproc instance.
    """
    self.dataproc = dataproc
    # NOTE(review): attribute name has a typo ('fist' -> 'first'); it is read
    # by TrackerUpdateFunction below, so a rename must update both places.
    self.fist_tick_message_printed = False
  def IsDone(self, batch):
    """See base class.

    Args:
      batch: Batch resource returned by Poll, or None if the fetch failed.

    Returns:
      True if the batch reached a terminal state, False otherwise.
    """
    if batch and batch.state in (
        self.dataproc.messages.Batch.StateValueValuesEnum.SUCCEEDED,
        self.dataproc.messages.Batch.StateValueValuesEnum.CANCELLED,
        self.dataproc.messages.Batch.StateValueValuesEnum.FAILED,
    ):
      return True
    return False
def Poll(self, batch_ref):
"""See base class."""
request = self.dataproc.messages.DataprocProjectsLocationsBatchesGetRequest(
name=batch_ref
)
try:
return self.dataproc.client.projects_locations_batches.Get(request)
except apitools_exceptions.HttpError as error:
log.warning('Get Batch failed:\n{}'.format(error))
if util.IsClientHttpException(error):
# Stop polling if encounter client Http error (4xx).
raise
def GetResult(self, batch):
"""Handles errors.
Error handling for batch jobs. This happen after the batch reaches one of
the complete states.
Overrides.
Args:
batch: The batch resource.
Returns:
None. The result is directly output to log.err.
Raises:
JobTimeoutError: When waiter timed out.
JobError: When remote batch job is failed.
"""
if not batch:
# Batch resource is None but polling is considered done.
# This only happens when the waiter timed out.
raise exceptions.JobTimeoutError('Timed out while waiting for batch job.')
if (
batch.state
== self.dataproc.messages.Batch.StateValueValuesEnum.CANCELLED
):
log.warning('Batch job is CANCELLED.')
elif (
batch.state == self.dataproc.messages.Batch.StateValueValuesEnum.FAILED
):
err_message = 'Batch job is FAILED.'
if batch.stateMessage:
err_message = '{} Detail: {}'.format(err_message, batch.stateMessage)
if err_message[-1] != '.':
err_message += '.'
err_message += '\n'
err_message += (
'Running auto diagnostics on the batch. It may take few '
'minutes before diagnostics output is available. Please '
"check diagnostics output by running 'gcloud dataproc "
"batches describe' command."
)
raise exceptions.JobError(err_message)
# Nothing to return, since the result is directly output to users.
return None
def TrackerUpdateFunction(self, tracker, poll_result, status):
"""Prints links to cloud console after the first success pull."""
if not self.fist_tick_message_printed:
self.fist_tick_message_printed = True
cloud_logging_url = cloud_console_url_helper.get_batch_logging_url(
poll_result
)
dataproc_batch_url = cloud_console_url_helper.get_dataproc_batch_url(
poll_result
)
log_level = log.GetVerbosity()
log.SetVerbosity(logging.INFO)
log.info(
'Please check the driver output in Cloud Logging: %s. (The log can'
' take a few minutes to show up.) You can visit the batch resource'
' at %s',
cloud_logging_url,
dataproc_batch_url,
)
log.info('Waiting for the batch to complete.')
log.SetVerbosity(log_level)

View File

@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Waiter utility for api_lib.util.waiter.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import exceptions as apitools_exceptions
from googlecloudsdk.api_lib.dataproc import exceptions
from googlecloudsdk.api_lib.dataproc import util
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import log
class SessionPoller(waiter.OperationPoller):
  """Poller for session workload."""

  def __init__(self, dataproc):
    """Poller for session workload."""
    self.dataproc = dataproc

  def IsDone(self, session):
    """See base class. Done once the session is ACTIVE or FAILED."""
    terminal_states = (
        self.dataproc.messages.Session.StateValueValuesEnum.ACTIVE,
        self.dataproc.messages.Session.StateValueValuesEnum.FAILED,
    )
    return session and session.state in terminal_states

  def Poll(self, session_ref):
    """See base class. Fetches the current session resource."""
    get_request = (
        self.dataproc.messages.DataprocProjectsLocationsSessionsGetRequest(
            name=session_ref))
    try:
      return self.dataproc.client.projects_locations_sessions.Get(get_request)
    except apitools_exceptions.HttpError as error:
      log.warning('Get session failed:\n{}'.format(error))
      if util.IsClientHttpException(error):
        # Client (4xx) errors will not recover on retry; stop polling.
        raise

  def GetResult(self, session):
    """Handles errors.

    Error handling for sessions. This happen after the session reaches one of
    the complete states.

    Overrides.

    Args:
      session: The session resource.

    Returns:
      None. The result is directly output to log.err.

    Raises:
      OperationTimeoutError: When waiter timed out.
      OperationError: When remote session creation is failed.
    """
    if not session:
      # Polling finished without a session resource: the waiter timed out.
      raise exceptions.OperationTimeoutError(
          'Timed out while waiting for session creation.')
    failed_state = self.dataproc.messages.Session.StateValueValuesEnum.FAILED
    if session.state == failed_state:
      err_message = 'Session creation is FAILED.'
      if session.stateMessage:
        err_message = '{} Detail: {}'.format(err_message, session.stateMessage)
        if err_message[-1] != '.':
          err_message += '.'
      raise exceptions.OperationError(err_message)
    # Nothing to return.
    return None

  def TrackerUpdateFunction(self, tracker, poll_result, status):
    """No-op; sessions print no extra progress information per tick."""
    pass

View File

@@ -0,0 +1,342 @@
# -*- coding: utf-8 -*- #
# Copyright 2015 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for accessing GCS.
Bulk object uploads and downloads use methods that shell out to gsutil.
Lightweight metadata / streaming operations use the StorageClient class.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import io
import os
import sys
from apitools.base.py import exceptions as apitools_exceptions
from apitools.base.py import transfer
from googlecloudsdk.api_lib.dataproc import exceptions as dp_exceptions
from googlecloudsdk.api_lib.storage import storage_api
from googlecloudsdk.api_lib.storage import storage_util
from googlecloudsdk.api_lib.util import apis as core_apis
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
import six.moves.urllib.parse
# URI scheme for GCS.
STORAGE_SCHEME = 'gs'

# Timeout for individual socket connections. Matches gsutil.
HTTP_TIMEOUT = 60

# Fix urlparse for storage paths.
# This allows using urljoin in other files (that import this).
# Registering the 'gs' scheme makes urllib treat 'gs://bucket/obj' like an
# http-style URL for relative-reference resolution and netloc parsing.
six.moves.urllib.parse.uses_relative.append(STORAGE_SCHEME)
six.moves.urllib.parse.uses_netloc.append(STORAGE_SCHEME)
def Upload(files, destination, storage_client=None):
  """Uploads local files to a GCS destination.

  Dispatches to the gsutil implementation when the storage/use_gsutil
  property is set, otherwise uses the storage API client.

  Args:
    files: The list of local files to upload.
    destination: A GCS "directory" to copy the files into.
    storage_client: Storage api client used to copy files to gcs.
  """
  # TODO(b/109938541): Remove gsutil implementation after the new
  # implementation seems stable.
  if properties.VALUES.storage.use_gsutil.GetBool():
    _UploadGsutil(files, destination)
  else:
    _UploadStorageClient(files, destination, storage_client=storage_client)
def _UploadStorageClient(files, destination, storage_client=None):
  """Uploads a list of local files to GCS via the storage API.

  Args:
    files: The list of local files to upload.
    destination: A GCS "directory" to copy the files into.
    storage_client: Storage api client used to copy files to gcs.

  Raises:
    dp_exceptions.FileUploadError: If copying any file fails.
  """
  client = storage_client or storage_api.StorageClient()
  for local_file in files:
    target_url = os.path.join(destination, os.path.basename(local_file))
    target_ref = storage_util.ObjectReference.FromUrl(target_url)
    try:
      client.CopyFileToGCS(local_file, target_ref)
    except exceptions.BadFileException as err:
      raise dp_exceptions.FileUploadError(
          "Failed to upload files ['{}'] to '{}': {}".format(
              "', '".join(files), destination, err))
def _UploadGsutil(files, destination):
  """Upload a list of local files to GCS using gsutil.

  Args:
    files: The list of local files to upload.
    destination: A GCS "directory" to copy the files into.

  Raises:
    dp_exceptions.FileUploadError: If gsutil exits with a non-zero status.
  """
  # Build a fresh argument list. The previous implementation did
  # `args = files; args += [destination]`, which appended the destination
  # onto the caller's list in place.
  args = list(files) + [destination]
  exit_code = storage_util.RunGsutilCommand('cp', args)
  if exit_code != 0:
    raise dp_exceptions.FileUploadError(
        "Failed to upload files ['{0}'] to '{1}' using gsutil.".format(
            "', '".join(files), destination))
def GetBucket(bucket, storage_client=None):
  """Gets a bucket if it exists.

  Args:
    bucket: The bucket name.
    storage_client: Storage client instance.

  Returns:
    A bucket message, or None if it doesn't exist.
  """
  gcs_client = storage_client if storage_client else storage_api.StorageClient()
  try:
    return gcs_client.GetBucket(bucket)
  except storage_api.BucketNotFoundError:
    # A missing bucket is an expected outcome here, not an error.
    return None
def CreateBucketIfNotExists(bucket, region, storage_client=None, project=None):
  """Creates a bucket.

  Creates a bucket in the specified region. If the region is None, the bucket
  will be created in global region.

  Args:
    bucket: Name of bucket to create.
    region: Region to create bucket in.
    storage_client: Storage client instance.
    project: The project to create the bucket in. If None, current Cloud SDK
      project is used.
  """
  gcs_client = storage_client if storage_client else storage_api.StorageClient()
  gcs_client.CreateBucketIfNotExists(bucket, location=region, project=project)
def ReadObject(object_url, storage_client=None):
  """Reads an object's content from GCS.

  Args:
    object_url: The URL of the object to be read. Must have "gs://" prefix.
    storage_client: Storage api client used to read files from gcs.

  Raises:
    ObjectReadError:
      If the read of GCS object is not successful.

  Returns:
    A str for the content of the GCS object.
  """
  gcs_client = storage_client or storage_api.StorageClient()
  target_ref = storage_util.ObjectReference.FromUrl(object_url)
  try:
    raw_stream = gcs_client.ReadObject(target_ref)
    # Decode the downloaded byte stream as UTF-8 text.
    return io.TextIOWrapper(raw_stream, encoding='utf-8').read()
  except exceptions.BadFileException:
    raise dp_exceptions.ObjectReadError(
        "Failed to read file '{0}'.".format(object_url))
def GetObjectRef(path, messages):
  """Build an Object proto message from a GCS path."""
  parsed = resources.REGISTRY.ParseStorageURL(path)
  return messages.Object(bucket=parsed.bucket, name=parsed.object)
class StorageClient(object):
  """Micro-client for accessing GCS."""

  # TODO(b/36050236): Add application-id.

  def __init__(self):
    self.client = core_apis.GetClientInstance('storage', 'v1')
    self.messages = core_apis.GetMessagesModule('storage', 'v1')

  def _GetObject(self, object_ref, download=None):
    """Issues objects.Get for object_ref; returns None on a 404."""
    get_request = self.messages.StorageObjectsGetRequest(
        bucket=object_ref.bucket, object=object_ref.name)
    try:
      return self.client.objects.Get(request=get_request, download=download)
    except apitools_exceptions.HttpNotFoundError:
      # TODO(b/36052479): Clean up error handling. Handle 403s.
      return None

  def GetObject(self, object_ref):
    """Get the object metadata of a GCS object.

    Args:
      object_ref: A proto message of the object to fetch. Only the bucket and
        name need be set.

    Raises:
      HttpError:
        If the responses status is not 2xx or 404.

    Returns:
      The object if it exists otherwise None.
    """
    return self._GetObject(object_ref)

  def BuildObjectStream(self, stream, object_ref):
    """Build an apitools Download from a stream and a GCS object reference.

    Note: This will always succeed, but HttpErrors with downloading will be
    raised when the download's methods are called.

    Args:
      stream: An Stream-like object that implements write(<string>) to write
        into.
      object_ref: A proto message of the object to fetch. Only the bucket and
        name need be set.

    Returns:
      The download.
    """
    object_download = transfer.Download.FromStream(
        stream, total_size=object_ref.size, auto_transfer=False)
    # Attach the download to the Get request so apitools wires up the
    # download URL; the actual bytes are pulled lazily by the caller.
    self._GetObject(object_ref, download=object_download)
    return object_download
class StorageObjectSeriesStream(object):
  """I/O Stream-like class for communicating via a sequence of GCS objects."""

  def __init__(self, path, storage_client=None):
    """Construct a StorageObjectSeriesStream for a specific gcs path.

    Args:
      path: A GCS object prefix which will be the base of the objects used to
        communicate across the channel.
      storage_client: a StorageClient for accessing GCS.

    Returns:
      The constructed stream.
    """
    self._base_path = path
    self._gcs = storage_client or StorageClient()
    self._open = True

    # Index of current object in series.
    self._current_object_index = 0

    # Current position in bytes in the current file.
    self._current_object_pos = 0

  @property
  def open(self):
    """Whether the stream is open."""
    return self._open

  def Close(self):
    """Close the stream."""
    self._open = False

  def _AssertOpen(self):
    # Mirrors the ValueError file objects raise on closed-stream I/O.
    if not self.open:
      raise ValueError('I/O operation on closed stream.')

  def _GetObject(self, i):
    """Get the ith object in the series."""
    # Objects in the series are named '<base>.000000000', '<base>.000000001',
    # etc., so the series can be enumerated by index.
    path = '{0}.{1:09d}'.format(self._base_path, i)
    return self._gcs.GetObject(GetObjectRef(path, self._gcs.messages))

  def ReadIntoWritable(self, writable, n=sys.maxsize):
    """Read from this stream into a writable.

    Reads at most n bytes, or until it sees there is not a next object in the
    series. This will block for the duration of each object's download,
    and possibly indefinitely if new objects are being added to the channel
    frequently enough.

    Args:
      writable: The stream-like object that implements write(<string>) to
        write into.
      n: A maximum number of bytes to read. Defaults to sys.maxsize
        (usually ~4 GB).

    Raises:
      ValueError: If the stream is closed or objects in the series are
        detected to shrink.

    Returns:
      The number of bytes read.
    """
    self._AssertOpen()
    bytes_read = 0
    object_info = None
    max_bytes_to_read = n
    while bytes_read < max_bytes_to_read:
      # Cache away next object first.
      next_object_info = self._GetObject(self._current_object_index + 1)

      # If next object exists always fetch current object to get final size.
      if not object_info or next_object_info:
        try:
          object_info = self._GetObject(self._current_object_index)
        except apitools_exceptions.HttpError as error:
          # Treat transient fetch failures as "no more data for now" and
          # return what was read so far rather than raising.
          log.warning('Failed to fetch GCS output:\n%s', error)
          break
      if not object_info:
        # Nothing to read yet.
        break
      new_bytes_available = object_info.size - self._current_object_pos
      if new_bytes_available < 0:
        raise ValueError('Object [{0}] shrunk.'.format(object_info.name))
      if object_info.size == 0:
        # There are no more objects to read
        self.Close()
        break
      bytes_left_to_read = max_bytes_to_read - bytes_read
      new_bytes_to_read = min(bytes_left_to_read, new_bytes_available)
      if new_bytes_to_read > 0:
        # Download range.
        download = self._gcs.BuildObjectStream(writable, object_info)
        download.GetRange(
            self._current_object_pos,
            self._current_object_pos + new_bytes_to_read - 1)
        self._current_object_pos += new_bytes_to_read
        bytes_read += new_bytes_to_read

      # Correct since we checked for next object before getting current
      # object's size.
      object_finished = (
          next_object_info and self._current_object_pos == object_info.size)

      if object_finished:
        # Advance to the next object in the series and restart from offset 0.
        object_info = next_object_info
        self._current_object_index += 1
        self._current_object_pos = 0
        continue
      else:
        # That is all there is to read at this time.
        break
    return bytes_read

View File

@@ -0,0 +1,969 @@
# -*- coding: utf-8 -*- #
# Copyright 2015 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for the gcloud dataproc tool."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import base64
import hashlib
import json
import os
import subprocess
import tempfile
import time
import uuid
from apitools.base.py import encoding
from apitools.base.py import exceptions as apitools_exceptions
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.dataproc import exceptions
from googlecloudsdk.api_lib.dataproc import storage_helpers
from googlecloudsdk.calliope import arg_parsers
from googlecloudsdk.command_lib.export import util as export_util
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import requests
from googlecloudsdk.core.console import console_attr
from googlecloudsdk.core.console import console_io
from googlecloudsdk.core.console import progress_tracker
from googlecloudsdk.core.credentials import creds as c_creds
from googlecloudsdk.core.credentials import store as c_store
from googlecloudsdk.core.util import retry
import six
# Directory holding the YAML schemas used for config export/import, located
# next to this module.
SCHEMA_DIR = os.path.join(os.path.dirname(__file__), 'schemas')
def FormatRpcError(error):
  """Returns a printable representation of a failed Google API's status.proto.

  Args:
    error: the failed Status to print.

  Returns:
    A ready-to-print string representation of the error.
  """
  # Log the full serialized status for debugging; surface only the message.
  serialized = encoding.MessageToJson(error)
  log.debug('Error:\n' + serialized)
  return error.message
def WaitForResourceDeletion(request_method,
                            resource_ref,
                            message,
                            timeout_s=60,
                            poll_period_s=5):
  """Poll Dataproc resource until it no longer exists.

  Args:
    request_method: Callable that issues a Get request for the resource; a
      404 (HttpNotFoundError) from it means the resource is gone.
    resource_ref: Reference to the resource being deleted.
    message: str, message to display to user while polling.
    timeout_s: number, seconds to poll with retries before timing out.
    poll_period_s: number, delay in seconds between requests.

  Raises:
    exceptions.OperationTimeoutError: If the resource still exists after
      timeout_s seconds.
    apitools_exceptions.HttpError: On client (4xx) errors other than 404.
  """
  with progress_tracker.ProgressTracker(message, autotick=True):
    start_time = time.time()
    while timeout_s > (time.time() - start_time):
      try:
        request_method(resource_ref)
      except apitools_exceptions.HttpNotFoundError:
        # Object deleted
        return
      except apitools_exceptions.HttpError as error:
        # log.debug uses %-style lazy formatting; the previous call passed
        # '{0}'-style placeholders with bare args, which never interpolated.
        log.debug('Get request for [{0}] failed:\n{1}'.format(
            resource_ref, error))

        # Do not retry on 4xx errors
        if IsClientHttpException(error):
          raise
      time.sleep(poll_period_s)
  raise exceptions.OperationTimeoutError(
      'Deleting resource [{0}] timed out.'.format(resource_ref))
def GetUniqueId():
  """Returns a random 32-character lowercase hex string for use as an id."""
  return '{:032x}'.format(uuid.uuid4().int)
class Bunch(object):
  """Class that converts a dictionary to javascript like object.

  For example:
      Bunch({'a': {'b': {'c': 0}}}).a.b.c == 0
  """

  def __init__(self, dictionary):
    # Nested dicts become nested Bunches; all other values pass through.
    for key, value in dictionary.items():
      self.__dict__[key] = Bunch(value) if isinstance(value, dict) else value
def AddJvmDriverFlags(parser):
  """Registers the --jar and --class driver flags on the given parser."""
  flag_specs = [
      ('--jar', 'main_jar',
       'The HCFS URI of jar file containing the driver jar.'),
      ('--class', 'main_class',
       ('The class containing the main method of the driver. Must be in a'
        ' provided jar or jar that is already on the classpath')),
  ]
  for flag, dest, help_text in flag_specs:
    parser.add_argument(flag, dest=dest, help=help_text)
def IsClientHttpException(http_exception):
  """Returns true if the http exception given is an HTTP 4xx error."""
  return 400 <= http_exception.status_code < 500
# TODO(b/36056506): Use api_lib.utils.waiter
def WaitForOperation(dataproc, operation, message, timeout_s, poll_period_s=5):
  """Poll dataproc Operation until its status is done or timeout reached.

  Args:
    dataproc: wrapper for Dataproc messages, resources, and client
    operation: Operation, message of the operation to be polled.
    message: str, message to display to user while polling.
    timeout_s: number, seconds to poll with retries before timing out.
    poll_period_s: number, delay in seconds between requests.

  Returns:
    Operation: the return value of the last successful operations.get
    request.

  Raises:
    OperationTimeoutError: if the operation times out.
    OperationError: if the operation finishes with an error.
  """
  request = dataproc.messages.DataprocProjectsRegionsOperationsGetRequest(
      name=operation.name)
  log.status.Print('Waiting on operation [{0}].'.format(operation.name))
  start_time = time.time()
  # Number of operation warnings already printed, so each is shown only once.
  warnings_so_far = 0
  is_tty = console_io.IsInteractive(error=True)
  tracker_separator = '\n' if is_tty else ''

  def _LogWarnings(warnings):
    # Print only warnings not seen on a previous poll.
    new_warnings = warnings[warnings_so_far:]
    if new_warnings:
      # Drop a line to print nicely with the progress tracker.
      log.err.write(tracker_separator)
      for warning in new_warnings:
        log.warning(warning)

  with progress_tracker.ProgressTracker(message, autotick=True):
    while timeout_s > (time.time() - start_time):
      try:
        operation = dataproc.client.projects_regions_operations.Get(request)
        metadata = ParseOperationJsonMetadata(
            operation.metadata, dataproc.messages.ClusterOperationMetadata)
        _LogWarnings(metadata.warnings)
        warnings_so_far = len(metadata.warnings)
        if operation.done:
          break
      except apitools_exceptions.HttpError as http_exception:
        # Do not retry on 4xx errors.
        if IsClientHttpException(http_exception):
          raise
      time.sleep(poll_period_s)
  # Re-parse and re-log after the loop so warnings from the final poll are
  # not lost.
  metadata = ParseOperationJsonMetadata(
      operation.metadata, dataproc.messages.ClusterOperationMetadata)
  _LogWarnings(metadata.warnings)
  if not operation.done:
    raise exceptions.OperationTimeoutError('Operation [{0}] timed out.'.format(
        operation.name))
  elif operation.error:
    raise exceptions.OperationError('Operation [{0}] failed: {1}.'.format(
        operation.name, FormatRpcError(operation.error)))

  log.info('Operation [%s] finished after %.3f seconds', operation.name,
           (time.time() - start_time))
  return operation
def PrintWorkflowMetadata(metadata, status, operations, errors):
  """Print workflow and job status for the running workflow template.

  This method will detect any changes of state in the latest metadata and print
  all the new states in a workflow template.

  For example:
    Workflow template template-name RUNNING
    Creating cluster: Operation ID create-id.
    Job ID job-id-1 RUNNING
    Job ID job-id-1 COMPLETED
    Deleting cluster: Operation ID delete-id.
    Workflow template template-name DONE

  Args:
    metadata: Dataproc WorkflowMetadata message object, contains the latest
      states of a workflow template.
    status: Dictionary, stores all jobs' status in the current workflow
      template, as well as the status of the overarching workflow. Mutated in
      place between calls to track what has already been printed.
    operations: Dictionary, stores cluster operation status for the workflow
      template. Mutated in place between calls.
    errors: Dictionary, stores errors from the current workflow template.
      Mutated in place between calls.
  """
  # Key chosen to avoid collision with job ids, which are at least 3 characters.
  template_key = 'wt'
  if template_key not in status or metadata.state != status[template_key]:
    if metadata.template is not None:
      log.status.Print('WorkflowTemplate [{0}] {1}'.format(
          metadata.template, metadata.state))
    else:
      # Workflows instantiated inline do not store an id in their metadata.
      log.status.Print('WorkflowTemplate {0}'.format(metadata.state))
    status[template_key] = metadata.state
  if metadata.createCluster != operations['createCluster']:
    # Print the most specific create-cluster progress available:
    # error, then completion, then the in-flight operation id.
    if hasattr(metadata.createCluster,
               'error') and metadata.createCluster.error is not None:
      log.status.Print(metadata.createCluster.error)
    elif hasattr(metadata.createCluster,
                 'done') and metadata.createCluster.done is not None:
      log.status.Print('Created cluster: {0}.'.format(metadata.clusterName))
    elif hasattr(
        metadata.createCluster,
        'operationId') and metadata.createCluster.operationId is not None:
      log.status.Print('Creating cluster: Operation ID [{0}].'.format(
          metadata.createCluster.operationId))
    operations['createCluster'] = metadata.createCluster
  if hasattr(metadata.graph, 'nodes'):
    for node in metadata.graph.nodes:
      if not node.jobId:
        continue
      # Print a line whenever a job enters a state it has not been seen in.
      if node.jobId not in status or status[node.jobId] != node.state:
        log.status.Print('Job ID {0} {1}'.format(node.jobId, node.state))
        status[node.jobId] = node.state
      # Print each distinct job error only once.
      if node.error and (node.jobId not in errors or
                         errors[node.jobId] != node.error):
        log.status.Print('Job ID {0} error: {1}'.format(node.jobId, node.error))
        errors[node.jobId] = node.error
  if metadata.deleteCluster != operations['deleteCluster']:
    # Same precedence as createCluster: error, then done, then operation id.
    if hasattr(metadata.deleteCluster,
               'error') and metadata.deleteCluster.error is not None:
      log.status.Print(metadata.deleteCluster.error)
    elif hasattr(metadata.deleteCluster,
                 'done') and metadata.deleteCluster.done is not None:
      log.status.Print('Deleted cluster: {0}.'.format(metadata.clusterName))
    elif hasattr(
        metadata.deleteCluster,
        'operationId') and metadata.deleteCluster.operationId is not None:
      log.status.Print('Deleting cluster: Operation ID [{0}].'.format(
          metadata.deleteCluster.operationId))
    operations['deleteCluster'] = metadata.deleteCluster
# TODO(b/36056506): Use api_lib.utils.waiter
def WaitForWorkflowTemplateOperation(dataproc,
                                     operation,
                                     timeout_s=None,
                                     poll_period_s=5):
  """Poll dataproc Operation until its status is done or timeout reached.

  Args:
    dataproc: wrapper for Dataproc messages, resources, and client
    operation: Operation, message of the operation to be polled.
    timeout_s: number, seconds to poll with retries before timing out.
      None means poll forever.
    poll_period_s: number, delay in seconds between requests.

  Returns:
    Operation: the return value of the last successful operations.get
    request.

  Raises:
    OperationTimeoutError: if the operation times out.
    OperationError: if the operation, or any of its cluster create/delete
      sub-operations, finishes with an error.
  """
  request = dataproc.messages.DataprocProjectsRegionsOperationsGetRequest(
      name=operation.name)
  log.status.Print('Waiting on operation [{0}].'.format(operation.name))
  start_time = time.time()
  # Mutable state threaded through PrintWorkflowMetadata so each state change
  # is printed exactly once across polls.
  operations = {'createCluster': None, 'deleteCluster': None}
  status = {}
  errors = {}
  # If no timeout is specified, poll forever.
  while timeout_s is None or timeout_s > (time.time() - start_time):
    try:
      operation = dataproc.client.projects_regions_operations.Get(request)
      metadata = ParseOperationJsonMetadata(operation.metadata,
                                            dataproc.messages.WorkflowMetadata)
      PrintWorkflowMetadata(metadata, status, operations, errors)
      if operation.done:
        break
    except apitools_exceptions.HttpError as http_exception:
      # Do not retry on 4xx errors.
      if IsClientHttpException(http_exception):
        raise
    time.sleep(poll_period_s)

  metadata = ParseOperationJsonMetadata(operation.metadata,
                                        dataproc.messages.WorkflowMetadata)

  if not operation.done:
    raise exceptions.OperationTimeoutError('Operation [{0}] timed out.'.format(
        operation.name))
  elif operation.error:
    raise exceptions.OperationError('Operation [{0}] failed: {1}.'.format(
        operation.name, FormatRpcError(operation.error)))
  # Surface failures from the cluster create/delete sub-operations even when
  # the overall operation did not carry an error.
  for op in ['createCluster', 'deleteCluster']:
    if op in operations and operations[op] is not None and operations[op].error:
      raise exceptions.OperationError('Operation [{0}] failed: {1}.'.format(
          operations[op].operationId, operations[op].error))

  log.info('Operation [%s] finished after %.3f seconds', operation.name,
           (time.time() - start_time))
  return operation
class NoOpProgressDisplay(object):
  """For use in place of a ProgressTracker in a 'with' block.

  Does nothing on entry or exit, and never suppresses exceptions.
  """

  def __enter__(self):
    return None

  def __exit__(self, *unused_args):
    return None
def WaitForJobTermination(dataproc,
                          job,
                          job_ref,
                          message,
                          goal_state,
                          error_state=None,
                          stream_driver_log=False,
                          log_poll_period_s=1,
                          dataproc_poll_period_s=10,
                          timeout_s=None):
  """Poll dataproc Job until its status is terminal or timeout reached.

  Args:
    dataproc: wrapper for dataproc resources, client and messages
    job: The job to wait to finish.
    job_ref: Parsed dataproc.projects.regions.jobs resource containing a
      projectId, region, and jobId.
    message: str, message to display to user while polling.
    goal_state: JobStatus.StateValueValuesEnum, the state to define success
    error_state: JobStatus.StateValueValuesEnum, the state to define failure
    stream_driver_log: bool, Whether to show the Job's driver's output.
    log_poll_period_s: number, delay in seconds between checking on the log.
    dataproc_poll_period_s: number, delay in seconds between requests to the
      Dataproc API.
    timeout_s: number, time out for job completion. None means no timeout.

  Returns:
    Job: the return value of the last successful jobs.get request.

  Raises:
    JobError: if the job finishes with an error.
    JobTimeoutError: if the job does not reach a terminal state in time.
  """
  request = dataproc.messages.DataprocProjectsRegionsJobsGetRequest(
      projectId=job_ref.projectId, region=job_ref.region, jobId=job_ref.jobId)
  driver_log_stream = None
  last_job_poll_time = 0
  job_complete = False
  wait_display = None
  # Last driver output URI seen; a change indicates a new job attempt.
  driver_output_uri = None

  def ReadDriverLogIfPresent():
    # Drain any newly-available driver output to stderr.
    if driver_log_stream and driver_log_stream.open:
      # TODO(b/36049794): Don't read all output.
      driver_log_stream.ReadIntoWritable(log.err)

  def PrintEqualsLine():
    # Full-terminal-width separator between output of different job attempts.
    attr = console_attr.GetConsoleAttr()
    log.err.Print('=' * attr.GetTermSize()[0])

  if stream_driver_log:
    log.status.Print('Waiting for job output...')
    wait_display = NoOpProgressDisplay()
  else:
    wait_display = progress_tracker.ProgressTracker(message, autotick=True)
  start_time = now = time.time()
  with wait_display:
    while not timeout_s or timeout_s > (now - start_time):
      # Poll logs first to see if it closed.
      ReadDriverLogIfPresent()
      log_stream_closed = driver_log_stream and not driver_log_stream.open
      if (not job_complete and
          job.status.state in dataproc.terminal_job_states):
        job_complete = True
        # Wait an 10s to get trailing output.
        timeout_s = now - start_time + 10

      if job_complete and (not stream_driver_log or log_stream_closed):
        # Nothing left to wait for
        break

      regular_job_poll = (
          not job_complete
          # Poll less frequently on dataproc API
          and now >= last_job_poll_time + dataproc_poll_period_s)
      # Poll at regular frequency before output has streamed and after it has
      # finished.
      expecting_output_stream = stream_driver_log and not driver_log_stream
      expecting_job_done = not job_complete and log_stream_closed
      if regular_job_poll or expecting_output_stream or expecting_job_done:
        last_job_poll_time = now
        try:
          job = dataproc.client.projects_regions_jobs.Get(request)
        except apitools_exceptions.HttpError as error:
          log.warning('GetJob failed:\n{}'.format(six.text_type(error)))
          # Do not retry on 4xx errors.
          if IsClientHttpException(error):
            raise
        if (stream_driver_log and job.driverOutputResourceUri and
            job.driverOutputResourceUri != driver_output_uri):
          # A changed output URI means a new job attempt started; separate
          # its output visually from the previous attempt's.
          if driver_output_uri:
            PrintEqualsLine()
            log.warning("Job attempt failed. Streaming new attempt's output.")
            PrintEqualsLine()
          driver_output_uri = job.driverOutputResourceUri
          driver_log_stream = storage_helpers.StorageObjectSeriesStream(
              job.driverOutputResourceUri)
      time.sleep(log_poll_period_s)
      now = time.time()

  state = job.status.state
  # goal_state and error_state will always be terminal
  if state in dataproc.terminal_job_states:
    if stream_driver_log:
      if not driver_log_stream:
        log.warning('Expected job output not found.')
      elif driver_log_stream.open:
        log.warning('Job terminated, but output did not finish streaming.')
    if state is goal_state:
      return job
    if error_state and state is error_state:
      if job.status.details:
        raise exceptions.JobError('Job [{0}] failed with error:\n{1}'.format(
            job_ref.jobId, job.status.details))
      raise exceptions.JobError('Job [{0}] failed.'.format(job_ref.jobId))
    if job.status.details:
      log.info('Details:\n' + job.status.details)
    raise exceptions.JobError(
        'Job [{0}] entered state [{1}] while waiting for [{2}].'.format(
            job_ref.jobId, state, goal_state))
  raise exceptions.JobTimeoutError(
      'Job [{0}] timed out while in state [{1}].'.format(job_ref.jobId, state))
# This replicates the fallthrough logic of flags._RegionAttributeConfig.
# It is necessary in cases like the --region flag where we are not parsing
# ResourceSpecs
def ResolveRegion():
  """Returns the dataproc region property value, raising if it is unset."""
  region_property = properties.VALUES.dataproc.region
  return region_property.GetOrFail()
# This replicates the fallthrough logic of flags._LocationAttributeConfig.
# It is necessary in cases like the --location flag where we are not parsing
# ResourceSpecs
def ResolveLocation():
  """Returns the dataproc location property value, raising if it is unset."""
  location_property = properties.VALUES.dataproc.location
  return location_property.GetOrFail()
# You probably want to use flags.AddClusterResourceArgument instead.
# If calling this method, you *must* have called flags.AddRegionFlag first to
# ensure a --region flag is stored into properties, which ResolveRegion
# depends on. This is also mutually incompatible with any usage of args.CONCEPTS
# which use --region as a resource attribute.
def ParseCluster(name, dataproc):
  """Parses a cluster name into a dataproc clusters resource reference."""
  return dataproc.resources.Parse(
      name,
      params={
          'region': ResolveRegion,
          'projectId': properties.VALUES.core.project.GetOrFail,
      },
      collection='dataproc.projects.regions.clusters')
# You probably want to use flags.AddJobResourceArgument instead.
# If calling this method, you *must* have called flags.AddRegionFlag first to
# ensure a --region flag is stored into properties, which ResolveRegion
# depends on. This is also mutually incompatible with any usage of args.CONCEPTS
# which use --region as a resource attribute.
def ParseJob(job_id, dataproc):
  """Returns a job resource reference for the given job id.

  Args:
    job_id: The job id (or fully-qualified resource name) to parse.
    dataproc: Wrapper for dataproc resources, client and messages.
  """
  # Resolver callables defer property lookups until the attribute is needed.
  fallthrough_params = {
      'projectId': properties.VALUES.core.project.GetOrFail,
      'region': ResolveRegion,
  }
  return dataproc.resources.Parse(
      job_id,
      params=fallthrough_params,
      collection='dataproc.projects.regions.jobs')
def ParseOperationJsonMetadata(metadata_value, metadata_type):
  """Returns an Operation message for a metadata value."""
  if metadata_value:
    # Round-trip through JSON to convert the metadata into the requested
    # message type.
    return encoding.JsonToMessage(metadata_type,
                                  encoding.MessageToJson(metadata_value))
  # No metadata provided: return an empty message of the requested type.
  return metadata_type()
# Used in bizarre scenarios where we want a qualified region rather than a
# short name
def ParseRegion(dataproc):
  """Returns a qualified dataproc.projects.regions resource reference.

  Args:
    dataproc: Wrapper for dataproc resources, client and messages.
  """
  fallthrough_params = {
      'projectId': properties.VALUES.core.project.GetOrFail,
      'regionId': ResolveRegion,
  }
  return dataproc.resources.Parse(
      None, params=fallthrough_params, collection='dataproc.projects.regions')
# Get dataproc.projects.locations resource
def ParseProjectsLocations(dataproc):
  """Returns a dataproc.projects.locations resource reference.

  The location attribute falls through to the dataproc/region property.

  Args:
    dataproc: Wrapper for dataproc resources, client and messages.
  """
  fallthrough_params = {
      'projectsId': properties.VALUES.core.project.GetOrFail,
      'locationsId': ResolveRegion,
  }
  return dataproc.resources.Parse(
      None,
      params=fallthrough_params,
      collection='dataproc.projects.locations')
# Get dataproc.projects.locations resource
# This can be merged with ParseProjectsLocations() once we have migrated batches
# from `region` to `location`.
def ParseProjectsLocationsForSession(dataproc):
  """Returns a dataproc.projects.locations resource reference for sessions.

  Unlike ParseProjectsLocations, the location attribute falls through to the
  dataproc/location property rather than dataproc/region.

  Args:
    dataproc: Wrapper for dataproc resources, client and messages.
  """
  ref = dataproc.resources.Parse(
      None,
      params={
          # Pass the resolver itself (not its result) so the property is read
          # lazily, matching the fallthrough behavior of the other Parse*
          # helpers in this module (which pass ResolveRegion uncalled).
          'locationsId': ResolveLocation,
          'projectsId': properties.VALUES.core.project.GetOrFail
      },
      collection='dataproc.projects.locations')
  return ref
def ReadAutoscalingPolicy(dataproc, policy_id, policy_file_name=None):
  """Returns autoscaling policy read from YAML file.

  Args:
    dataproc: wrapper for dataproc resources, client and messages.
    policy_id: The autoscaling policy id (last piece of the resource name).
    policy_file_name: if set, location of the YAML file to read from. Otherwise,
      reads from stdin.

  Raises:
    argparse.ArgumentError if duration formats are invalid or out of bounds.
  """
  data = console_io.ReadFromFileOrStdin(policy_file_name or '-', binary=False)
  policy = export_util.Import(
      message_type=dataproc.messages.AutoscalingPolicy, stream=data)
  # Ignore user set id in the file (if any), and overwrite with the policy_ref
  # provided with this command
  policy.id = policy_id
  # Similarly, ignore the set resource name. This field is OUTPUT_ONLY, so we
  # can just clear it.
  policy.name = None
  # Normalize duration fields to their seconds values (validating bounds).
  if policy.basicAlgorithm is not None:
    if policy.basicAlgorithm.cooldownPeriod is not None:
      policy.basicAlgorithm.cooldownPeriod = str(
          arg_parsers.Duration(lower_bound='2m', upper_bound='1d')(
              policy.basicAlgorithm.cooldownPeriod)) + 's'
    # Guard against policies whose basicAlgorithm has no yarnConfig; the
    # original unconditional attribute access would raise AttributeError.
    if (policy.basicAlgorithm.yarnConfig is not None and
        policy.basicAlgorithm.yarnConfig.gracefulDecommissionTimeout
        is not None):
      policy.basicAlgorithm.yarnConfig.gracefulDecommissionTimeout = str(
          arg_parsers.Duration(lower_bound='0s', upper_bound='1d')
          (policy.basicAlgorithm.yarnConfig.gracefulDecommissionTimeout)) + 's'
  return policy
def CreateAutoscalingPolicy(dataproc, name, policy):
  """Returns the server-resolved policy after creating the given policy.

  Args:
    dataproc: wrapper for dataproc resources, client and messages.
    name: The autoscaling policy resource name.
    policy: The AutoscalingPolicy message to create.
  """
  # TODO(b/109837200) make the dataproc discovery doc parameters consistent
  # Parent() fails for the collection because of projectId/projectsId and
  # regionId/regionsId inconsistencies, so build the parent path by hand from
  # the first four segments of the resource name.
  name_segments = name.split('/')
  parent = '/'.join(name_segments[:4])
  create_request = (
      dataproc.messages.DataprocProjectsRegionsAutoscalingPoliciesCreateRequest(
          autoscalingPolicy=policy, parent=parent))
  created_policy = dataproc.client.projects_regions_autoscalingPolicies.Create(
      create_request)
  log.status.Print('Created [{0}].'.format(created_policy.id))
  return created_policy
def UpdateAutoscalingPolicy(dataproc, name, policy):
  """Returns the server-resolved policy after updating the given policy.

  Args:
    dataproc: wrapper for dataproc resources, client and messages.
    name: The autoscaling policy resource name.
    policy: The AutoscalingPolicy message to create.
  """
  # The API marks `name` OUTPUT_ONLY, but the generated client's Update()
  # still requires it to be populated.
  policy.name = name
  updated_policy = dataproc.client.projects_regions_autoscalingPolicies.Update(
      policy)
  log.status.Print('Updated [{0}].'.format(updated_policy.id))
  return updated_policy
def _DownscopeCredentials(token, access_boundary_json):
  """Downscope the given credentials to the given access boundary.

  Args:
    token: The credentials to downscope.
    access_boundary_json: The JSON-formatted access boundary.

  Returns:
    A downscoped credential with the given access boundary.
  """
  # Pick the STS endpoint for the configured universe; mTLS when a client
  # certificate is in use.
  universe_domain = properties.VALUES.core.universe_domain.Get()
  if properties.VALUES.context_aware.use_client_certificate.GetBool():
    cab_token_url = f'https://sts.mtls.{universe_domain}/v1/token'
  else:
    cab_token_url = f'https://sts.{universe_domain}/v1/token'
  exchange_payload = {
      'grant_type': 'urn:ietf:params:oauth:grant-type:token-exchange',
      'requested_token_type': 'urn:ietf:params:oauth:token-type:access_token',
      'subject_token_type': 'urn:ietf:params:oauth:token-type:access_token',
      'subject_token': token,
      'options': access_boundary_json
  }
  downscope_response = requests.GetSession().post(
      cab_token_url,
      headers={'Content-Type': 'application/x-www-form-urlencoded'},
      data=exchange_payload)
  if downscope_response.status_code != 200:
    raise ValueError('Error downscoping credentials')
  return json.loads(downscope_response.content).get('access_token', None)
def GetCredentials(access_boundary_json):
  """Get an access token for the user's current credentials.

  Args:
    access_boundary_json: JSON string holding the definition of the access
      boundary to apply to the credentials.

  Raises:
    PersonalAuthError: If no access token could be fetched for the user.

  Returns:
    An access token for the user.
  """
  creds = c_store.Load(
      None, allow_account_impersonation=True, use_google_auth=True)
  c_store.Refresh(creds)
  # OAuth2 client credentials keep the token under a different attribute name.
  token = (
      creds.access_token
      if c_creds.IsOauth2ClientCredentials(creds) else creds.token)
  if not token:
    raise exceptions.PersonalAuthError(
        'No access token could be obtained from the current credentials.')
  return _DownscopeCredentials(token, access_boundary_json)
class PersonalAuthUtils(object):
  """Util functions for enabling personal auth session."""

  def __init__(self):
    pass

  def _RunOpensslCommand(self, openssl_executable, args, stdin=None):
    """Run the specified command, capturing and returning output as appropriate.

    Args:
      openssl_executable: The path to the openssl executable.
      args: The arguments to the openssl command to run.
      stdin: The input to the command.

    Returns:
      The output of the command.

    Raises:
      PersonalAuthError: If the call to openssl fails
    """
    command = [openssl_executable]
    command.extend(args)
    stderr = None
    try:
      if getattr(subprocess, 'run', None):
        proc = subprocess.run(
            command,
            input=stdin,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False)
        stderr = proc.stderr.decode('utf-8').strip()
        # N.B. It would be better if we could simply call `subprocess.run` with
        # the `check` keyword arg set to true rather than manually calling
        # `check_returncode`. However, we want to capture the stderr when the
        # command fails, and the CalledProcessError type did not have a field
        # for the stderr until Python version 3.5.
        #
        # As such, we need to manually call `check_returncode` as long as we
        # are supporting Python versions prior to 3.5.
        proc.check_returncode()
        return proc.stdout
      else:
        # Fallback for interpreters without subprocess.run. NOTE(review): this
        # path never checks the process exit status — presumably a known
        # limitation for legacy Python; confirm before relying on it.
        p = subprocess.Popen(
            command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdout, _ = p.communicate(input=stdin)
        return stdout
    except Exception as ex:
      if stderr:
        log.error('OpenSSL command "%s" failed with error message "%s"',
                  ' '.join(command), stderr)
      raise exceptions.PersonalAuthError('Failure running openssl command: "' +
                                         ' '.join(command) + '": ' +
                                         six.text_type(ex))

  def _ComputeHmac(self, key, data, openssl_executable):
    """Compute HMAC tag using OpenSSL.

    Args:
      key: The HMAC key, hex-encoded.
      data: The bytes to authenticate.
      openssl_executable: The path to the openssl executable.

    Returns:
      The hex HMAC-SHA256 digest followed by a newline, encoded as UTF-8 bytes.

    Raises:
      PersonalAuthError: If the openssl output is not a 64-character hex digest.
    """
    cmd_output = self._RunOpensslCommand(
        openssl_executable, ['dgst', '-sha256', '-hmac', key],
        stdin=data).decode('utf-8')
    try:
      # Split the openssl output to get the HMAC (the digest follows a single
      # space in `openssl dgst` output).
      stripped_output = cmd_output.strip().split(' ')[1]
      if len(stripped_output) != 64:
        raise ValueError('HMAC output is expected to be 64 characters long.')
      int(stripped_output, 16)  # Check that the HMAC is in hex format.
    except Exception as ex:
      raise exceptions.PersonalAuthError(
          'Failure due to invalid openssl output: ' + six.text_type(ex))
    return (stripped_output + '\n').encode('utf-8')

  def _DeriveHkdfKey(self, prk, info, openssl_executable):
    """Derives HMAC-based Key Derivation Function (HKDF) key through expansion on the initial pseudorandom key.

    Args:
      prk: a pseudorandom key.
      info: optional context and application specific information (can be
        empty).
      openssl_executable: The path to the openssl executable.

    Returns:
      Output keying material, expected to be of 256-bit length.
    """
    if len(prk) != 32:
      raise ValueError(
          'The given initial pseudorandom key is expected to be 32 bytes long.')
    base16_prk = base64.b16encode(prk).decode('utf-8')
    # Expansion: T1 = HMAC(PRK, ""), result = HMAC(PRK, T1 | info | 0x01).
    t1 = self._ComputeHmac(base16_prk, b'', openssl_executable)
    t2data = bytearray(t1)
    t2data.extend(info)
    t2data.extend(b'\x01')
    return self._ComputeHmac(base16_prk, t2data, openssl_executable)

  # It is possible (although very rare) for the random pad generated by
  # openssl to not be usable by openssl for encrypting the secret. When
  # that happens the call to openssl will raise a CalledProcessError with
  # the message "Error reading password from BIO\nError getting password".
  #
  # To account for this we retry on that error, but this is so rare that
  # a single retry should be sufficient.
  @retry.RetryOnException(max_retrials=1)
  def _EncodeTokenUsingOpenssl(self, public_key, secret, openssl_executable):
    """Encode token using OpenSSL.

    Args:
      public_key: The public key for the session/cluster.
      secret: Token to be encrypted.
      openssl_executable: The path to the openssl executable.

    Returns:
      Encrypted token.
    """
    key_hash = hashlib.sha256((public_key + '\n').encode('utf-8')).hexdigest()
    iv_bytes = base64.b16encode(os.urandom(16))
    initialization_vector = iv_bytes.decode('utf-8')
    # Derive separate encryption and authentication keys from a single random
    # 256-bit master key.
    initial_key = os.urandom(32)
    encryption_key = self._DeriveHkdfKey(initial_key,
                                         'encryption_key'.encode('utf-8'),
                                         openssl_executable)
    auth_key = base64.b16encode(
        self._DeriveHkdfKey(initial_key, 'auth_key'.encode('utf-8'),
                            openssl_executable)).decode('utf-8')
    # RSA-OAEP-encrypt the master key with the provided public key (written to
    # a temp file because openssl reads the key from disk).
    with tempfile.NamedTemporaryFile() as kf:
      kf.write(public_key.encode('utf-8'))
      kf.seek(0)
      encrypted_key = self._RunOpensslCommand(
          openssl_executable,
          ['rsautl', '-oaep', '-encrypt', '-pubin', '-inkey', kf.name],
          stdin=base64.b64encode(initial_key))
    if len(encrypted_key) != 512:
      raise ValueError('The encrypted key is expected to be 512 bytes long.')
    encoded_key = base64.b64encode(encrypted_key).decode('utf-8')
    # AES-256-CTR-encrypt the secret with the derived encryption key. (A
    # verbatim duplicate of the encrypted_key length check that followed this
    # block was removed; the value cannot change after the check above.)
    with tempfile.NamedTemporaryFile() as pf:
      pf.write(encryption_key)
      pf.seek(0)
      encrypt_args = [
          'enc', '-aes-256-ctr', '-salt', '-iv', initialization_vector, '-pass',
          'file:{}'.format(pf.name)
      ]
      encrypted_token = self._RunOpensslCommand(
          openssl_executable, encrypt_args, stdin=secret.encode('utf-8'))
    encoded_token = base64.b64encode(encrypted_token).decode('utf-8')
    # Authenticate IV + ciphertext with the derived auth key.
    hmac_input = bytearray(iv_bytes)
    hmac_input.extend(encrypted_token)
    hmac_tag = self._ComputeHmac(auth_key, hmac_input,
                                 openssl_executable).decode('utf-8')[
                                     0:32]  # Truncate the HMAC tag to 128-bit
    return '{}:{}:{}:{}:{}'.format(key_hash, encoded_token, encoded_key,
                                   initialization_vector, hmac_tag)

  def EncryptWithPublicKey(self, public_key, secret, openssl_executable):
    """Encrypt secret with resource public key.

    Args:
      public_key: The public key for the session/cluster.
      secret: Token to be encrypted.
      openssl_executable: The path to the openssl executable; if unset, the
        Tink library is used instead.

    Returns:
      Encrypted token.

    Raises:
      PersonalAuthError: If Tink is required but cannot be imported.
    """
    if openssl_executable:
      return self._EncodeTokenUsingOpenssl(public_key, secret,
                                           openssl_executable)
    try:
      # pylint: disable=g-import-not-at-top
      import tink
      from tink import hybrid
      # pylint: enable=g-import-not-at-top
    except ImportError:
      raise exceptions.PersonalAuthError(
          'Cannot load the Tink cryptography library. Either the '
          'library is not installed, or site packages are not '
          'enabled for the Google Cloud SDK. Please consult Cloud '
          'Dataproc Personal Auth documentation on adding Tink to '
          'Google Cloud SDK for further instructions.\n'
          'https://cloud.google.com/dataproc/docs/concepts/iam/personal-auth')
    hybrid.register()
    context = b''
    # Extract value of key corresponding to primary key.
    public_key_value = json.loads(public_key)['key'][0]['keyData']['value']
    key_hash = hashlib.sha256(
        (public_key_value + '\n').encode('utf-8')).hexdigest()
    # Load public key and create keyset handle.
    reader = tink.JsonKeysetReader(public_key)
    kh_pub = tink.read_no_secret_keyset_handle(reader)
    # Create encrypter instance.
    encrypter = kh_pub.primitive(hybrid.HybridEncrypt)
    ciphertext = encrypter.encrypt(secret.encode('utf-8'), context)
    encoded_token = base64.b64encode(ciphertext).decode('utf-8')
    return '{}:{}'.format(key_hash, encoded_token)

  def IsTinkLibraryInstalled(self):
    """Check if Tink cryptography library can be loaded."""
    try:
      # pylint: disable=g-import-not-at-top
      # pylint: disable=unused-import
      import tink
      from tink import hybrid
      # pylint: enable=g-import-not-at-top
      # pylint: enable=unused-import
      return True
    except ImportError:
      return False
def ReadSessionTemplate(dataproc, template_file_name=None):
  """Returns session template read from YAML file.

  Args:
    dataproc: Wrapper for dataproc resources, client and messages.
    template_file_name: If set, location of the YAML file to read from.
      Otherwise, reads from stdin.

  Raises:
    argparse.ArgumentError if duration formats are invalid or out of bounds.
  """
  # '-' tells the reader to consume stdin instead of a file.
  source = template_file_name or '-'
  raw_yaml = console_io.ReadFromFileOrStdin(source, binary=False)
  return export_util.Import(
      message_type=dataproc.messages.SessionTemplate, stream=raw_yaml)
def CreateSessionTemplate(dataproc, name, template):
  """Returns the server-resolved template after creating the given template.

  Args:
    dataproc: Wrapper for dataproc resources, client and messages.
    name: The session template resource name.
    template: The SessionTemplate message to create.
  """
  template.name = name
  # The parent is the first four segments of the resource name
  # (projects/<p>/locations/<l>).
  name_segments = name.split('/')
  parent = '/'.join(name_segments[:4])
  create_request = (
      dataproc.messages.DataprocProjectsLocationsSessionTemplatesCreateRequest(
          sessionTemplate=template, parent=parent))
  created_template = dataproc.client.projects_locations_sessionTemplates.Create(
      create_request)
  log.status.Print('Created [{0}].'.format(created_template.name))
  return created_template
def UpdateSessionTemplate(dataproc, name, template):
  """Returns the server-resolved template after updating the given template.

  Args:
    dataproc: Wrapper for dataproc resources, client and messages.
    name: The session template resource name.
    template: The SessionTemplate message to create.
  """
  template.name = name
  updated_template = dataproc.client.projects_locations_sessionTemplates.Patch(
      template)
  log.status.Print('Updated [{0}].'.format(updated_template.name))
  return updated_template
def YieldFromListWithUnreachableList(unreachable_warning_msg, *args, **kwargs):
  """Yields from paged List calls handling unreachable list."""
  unreachable_resources = set()

  def _CollectAndGetField(message, attr):
    # Record each page's unreachable entries before handing the page's items
    # back to the pager.
    unreachable_resources.update(message.unreachable)
    return getattr(message, attr)

  yield from list_pager.YieldFromList(
      get_field_func=_CollectAndGetField, *args, **kwargs)
  # Warn once, after full iteration, with a stable (sorted) resource list.
  if unreachable_resources:
    log.warning(
        unreachable_warning_msg,
        ', '.join(sorted(unreachable_resources)),
    )