# -*- coding: utf-8 -*- # # Copyright 2015 Google LLC. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Common utilities for the gcloud dataproc tool.""" from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals import base64 import hashlib import json import os import subprocess import tempfile import time import uuid from apitools.base.py import encoding from apitools.base.py import exceptions as apitools_exceptions from apitools.base.py import list_pager from googlecloudsdk.api_lib.dataproc import exceptions from googlecloudsdk.api_lib.dataproc import storage_helpers from googlecloudsdk.calliope import arg_parsers from googlecloudsdk.command_lib.export import util as export_util from googlecloudsdk.core import log from googlecloudsdk.core import properties from googlecloudsdk.core import requests from googlecloudsdk.core.console import console_attr from googlecloudsdk.core.console import console_io from googlecloudsdk.core.console import progress_tracker from googlecloudsdk.core.credentials import creds as c_creds from googlecloudsdk.core.credentials import store as c_store from googlecloudsdk.core.util import retry import six SCHEMA_DIR = os.path.join(os.path.dirname(__file__), 'schemas') def FormatRpcError(error): """Returns a printable representation of a failed Google API's status.proto. Args: error: the failed Status to print. Returns: A ready-to-print string representation of the error. """ log.debug('Error:\n' + encoding.MessageToJson(error)) return error.message def WaitForResourceDeletion(request_method, resource_ref, message, timeout_s=60, poll_period_s=5): """Poll Dataproc resource until it no longer exists.""" with progress_tracker.ProgressTracker(message, autotick=True): start_time = time.time() while timeout_s > (time.time() - start_time): try: request_method(resource_ref) except apitools_exceptions.HttpNotFoundError: # Object deleted return except apitools_exceptions.HttpError as error: log.debug('Get request for [{0}] failed:\n{1}', resource_ref, error) # Do not retry on 4xx errors if IsClientHttpException(error): raise time.sleep(poll_period_s) raise exceptions.OperationTimeoutError( 'Deleting resource [{0}] timed out.'.format(resource_ref)) def GetUniqueId(): return uuid.uuid4().hex class Bunch(object): """Class that converts a dictionary to javascript like object. For example: Bunch({'a': {'b': {'c': 0}}}).a.b.c == 0 """ def __init__(self, dictionary): for key, value in six.iteritems(dictionary): if isinstance(value, dict): value = Bunch(value) self.__dict__[key] = value def AddJvmDriverFlags(parser): parser.add_argument( '--jar', dest='main_jar', help='The HCFS URI of jar file containing the driver jar.') parser.add_argument( '--class', dest='main_class', help=('The class containing the main method of the driver. Must be in a' ' provided jar or jar that is already on the classpath')) def IsClientHttpException(http_exception): """Returns true if the http exception given is an HTTP 4xx error.""" return http_exception.status_code >= 400 and http_exception.status_code < 500 # TODO(b/36056506): Use api_lib.utils.waiter def WaitForOperation(dataproc, operation, message, timeout_s, poll_period_s=5): """Poll dataproc Operation until its status is done or timeout reached. Args: dataproc: wrapper for Dataproc messages, resources, and client operation: Operation, message of the operation to be polled. message: str, message to display to user while polling. timeout_s: number, seconds to poll with retries before timing out. poll_period_s: number, delay in seconds between requests. Returns: Operation: the return value of the last successful operations.get request. Raises: OperationError: if the operation times out or finishes with an error. """ request = dataproc.messages.DataprocProjectsRegionsOperationsGetRequest( name=operation.name) log.status.Print('Waiting on operation [{0}].'.format(operation.name)) start_time = time.time() warnings_so_far = 0 is_tty = console_io.IsInteractive(error=True) tracker_separator = '\n' if is_tty else '' def _LogWarnings(warnings): new_warnings = warnings[warnings_so_far:] if new_warnings: # Drop a line to print nicely with the progress tracker. log.err.write(tracker_separator) for warning in new_warnings: log.warning(warning) with progress_tracker.ProgressTracker(message, autotick=True): while timeout_s > (time.time() - start_time): try: operation = dataproc.client.projects_regions_operations.Get(request) metadata = ParseOperationJsonMetadata( operation.metadata, dataproc.messages.ClusterOperationMetadata) _LogWarnings(metadata.warnings) warnings_so_far = len(metadata.warnings) if operation.done: break except apitools_exceptions.HttpError as http_exception: # Do not retry on 4xx errors. if IsClientHttpException(http_exception): raise time.sleep(poll_period_s) metadata = ParseOperationJsonMetadata( operation.metadata, dataproc.messages.ClusterOperationMetadata) _LogWarnings(metadata.warnings) if not operation.done: raise exceptions.OperationTimeoutError('Operation [{0}] timed out.'.format( operation.name)) elif operation.error: raise exceptions.OperationError('Operation [{0}] failed: {1}.'.format( operation.name, FormatRpcError(operation.error))) log.info('Operation [%s] finished after %.3f seconds', operation.name, (time.time() - start_time)) return operation def PrintWorkflowMetadata(metadata, status, operations, errors): """Print workflow and job status for the running workflow template. This method will detect any changes of state in the latest metadata and print all the new states in a workflow template. For example: Workflow template template-name RUNNING Creating cluster: Operation ID create-id. Job ID job-id-1 RUNNING Job ID job-id-1 COMPLETED Deleting cluster: Operation ID delete-id. Workflow template template-name DONE Args: metadata: Dataproc WorkflowMetadata message object, contains the latest states of a workflow template. status: Dictionary, stores all jobs' status in the current workflow template, as well as the status of the overarching workflow. operations: Dictionary, stores cluster operation status for the workflow template. errors: Dictionary, stores errors from the current workflow template. """ # Key chosen to avoid collision with job ids, which are at least 3 characters. template_key = 'wt' if template_key not in status or metadata.state != status[template_key]: if metadata.template is not None: log.status.Print('WorkflowTemplate [{0}] {1}'.format( metadata.template, metadata.state)) else: # Workflows instantiated inline do not store an id in their metadata. log.status.Print('WorkflowTemplate {0}'.format(metadata.state)) status[template_key] = metadata.state if metadata.createCluster != operations['createCluster']: if hasattr(metadata.createCluster, 'error') and metadata.createCluster.error is not None: log.status.Print(metadata.createCluster.error) elif hasattr(metadata.createCluster, 'done') and metadata.createCluster.done is not None: log.status.Print('Created cluster: {0}.'.format(metadata.clusterName)) elif hasattr( metadata.createCluster, 'operationId') and metadata.createCluster.operationId is not None: log.status.Print('Creating cluster: Operation ID [{0}].'.format( metadata.createCluster.operationId)) operations['createCluster'] = metadata.createCluster if hasattr(metadata.graph, 'nodes'): for node in metadata.graph.nodes: if not node.jobId: continue if node.jobId not in status or status[node.jobId] != node.state: log.status.Print('Job ID {0} {1}'.format(node.jobId, node.state)) status[node.jobId] = node.state if node.error and (node.jobId not in errors or errors[node.jobId] != node.error): log.status.Print('Job ID {0} error: {1}'.format(node.jobId, node.error)) errors[node.jobId] = node.error if metadata.deleteCluster != operations['deleteCluster']: if hasattr(metadata.deleteCluster, 'error') and metadata.deleteCluster.error is not None: log.status.Print(metadata.deleteCluster.error) elif hasattr(metadata.deleteCluster, 'done') and metadata.deleteCluster.done is not None: log.status.Print('Deleted cluster: {0}.'.format(metadata.clusterName)) elif hasattr( metadata.deleteCluster, 'operationId') and metadata.deleteCluster.operationId is not None: log.status.Print('Deleting cluster: Operation ID [{0}].'.format( metadata.deleteCluster.operationId)) operations['deleteCluster'] = metadata.deleteCluster # TODO(b/36056506): Use api_lib.utils.waiter def WaitForWorkflowTemplateOperation(dataproc, operation, timeout_s=None, poll_period_s=5): """Poll dataproc Operation until its status is done or timeout reached. Args: dataproc: wrapper for Dataproc messages, resources, and client operation: Operation, message of the operation to be polled. timeout_s: number, seconds to poll with retries before timing out. poll_period_s: number, delay in seconds between requests. Returns: Operation: the return value of the last successful operations.get request. Raises: OperationError: if the operation times out or finishes with an error. """ request = dataproc.messages.DataprocProjectsRegionsOperationsGetRequest( name=operation.name) log.status.Print('Waiting on operation [{0}].'.format(operation.name)) start_time = time.time() operations = {'createCluster': None, 'deleteCluster': None} status = {} errors = {} # If no timeout is specified, poll forever. while timeout_s is None or timeout_s > (time.time() - start_time): try: operation = dataproc.client.projects_regions_operations.Get(request) metadata = ParseOperationJsonMetadata(operation.metadata, dataproc.messages.WorkflowMetadata) PrintWorkflowMetadata(metadata, status, operations, errors) if operation.done: break except apitools_exceptions.HttpError as http_exception: # Do not retry on 4xx errors. if IsClientHttpException(http_exception): raise time.sleep(poll_period_s) metadata = ParseOperationJsonMetadata(operation.metadata, dataproc.messages.WorkflowMetadata) if not operation.done: raise exceptions.OperationTimeoutError('Operation [{0}] timed out.'.format( operation.name)) elif operation.error: raise exceptions.OperationError('Operation [{0}] failed: {1}.'.format( operation.name, FormatRpcError(operation.error))) for op in ['createCluster', 'deleteCluster']: if op in operations and operations[op] is not None and operations[op].error: raise exceptions.OperationError('Operation [{0}] failed: {1}.'.format( operations[op].operationId, operations[op].error)) log.info('Operation [%s] finished after %.3f seconds', operation.name, (time.time() - start_time)) return operation class NoOpProgressDisplay(object): """For use in place of a ProgressTracker in a 'with' block.""" def __enter__(self): pass def __exit__(self, *unused_args): pass def WaitForJobTermination(dataproc, job, job_ref, message, goal_state, error_state=None, stream_driver_log=False, log_poll_period_s=1, dataproc_poll_period_s=10, timeout_s=None): """Poll dataproc Job until its status is terminal or timeout reached. Args: dataproc: wrapper for dataproc resources, client and messages job: The job to wait to finish. job_ref: Parsed dataproc.projects.regions.jobs resource containing a projectId, region, and jobId. message: str, message to display to user while polling. goal_state: JobStatus.StateValueValuesEnum, the state to define success error_state: JobStatus.StateValueValuesEnum, the state to define failure stream_driver_log: bool, Whether to show the Job's driver's output. log_poll_period_s: number, delay in seconds between checking on the log. dataproc_poll_period_s: number, delay in seconds between requests to the Dataproc API. timeout_s: number, time out for job completion. None means no timeout. Returns: Job: the return value of the last successful jobs.get request. Raises: JobError: if the job finishes with an error. """ request = dataproc.messages.DataprocProjectsRegionsJobsGetRequest( projectId=job_ref.projectId, region=job_ref.region, jobId=job_ref.jobId) driver_log_stream = None last_job_poll_time = 0 job_complete = False wait_display = None driver_output_uri = None def ReadDriverLogIfPresent(): if driver_log_stream and driver_log_stream.open: # TODO(b/36049794): Don't read all output. driver_log_stream.ReadIntoWritable(log.err) def PrintEqualsLine(): attr = console_attr.GetConsoleAttr() log.err.Print('=' * attr.GetTermSize()[0]) if stream_driver_log: log.status.Print('Waiting for job output...') wait_display = NoOpProgressDisplay() else: wait_display = progress_tracker.ProgressTracker(message, autotick=True) start_time = now = time.time() with wait_display: while not timeout_s or timeout_s > (now - start_time): # Poll logs first to see if it closed. ReadDriverLogIfPresent() log_stream_closed = driver_log_stream and not driver_log_stream.open if (not job_complete and job.status.state in dataproc.terminal_job_states): job_complete = True # Wait an 10s to get trailing output. timeout_s = now - start_time + 10 if job_complete and (not stream_driver_log or log_stream_closed): # Nothing left to wait for break regular_job_poll = ( not job_complete # Poll less frequently on dataproc API and now >= last_job_poll_time + dataproc_poll_period_s) # Poll at regular frequency before output has streamed and after it has # finished. expecting_output_stream = stream_driver_log and not driver_log_stream expecting_job_done = not job_complete and log_stream_closed if regular_job_poll or expecting_output_stream or expecting_job_done: last_job_poll_time = now try: job = dataproc.client.projects_regions_jobs.Get(request) except apitools_exceptions.HttpError as error: log.warning('GetJob failed:\n{}'.format(six.text_type(error))) # Do not retry on 4xx errors. if IsClientHttpException(error): raise if (stream_driver_log and job.driverOutputResourceUri and job.driverOutputResourceUri != driver_output_uri): if driver_output_uri: PrintEqualsLine() log.warning("Job attempt failed. Streaming new attempt's output.") PrintEqualsLine() driver_output_uri = job.driverOutputResourceUri driver_log_stream = storage_helpers.StorageObjectSeriesStream( job.driverOutputResourceUri) time.sleep(log_poll_period_s) now = time.time() state = job.status.state # goal_state and error_state will always be terminal if state in dataproc.terminal_job_states: if stream_driver_log: if not driver_log_stream: log.warning('Expected job output not found.') elif driver_log_stream.open: log.warning('Job terminated, but output did not finish streaming.') if state is goal_state: return job if error_state and state is error_state: if job.status.details: raise exceptions.JobError('Job [{0}] failed with error:\n{1}'.format( job_ref.jobId, job.status.details)) raise exceptions.JobError('Job [{0}] failed.'.format(job_ref.jobId)) if job.status.details: log.info('Details:\n' + job.status.details) raise exceptions.JobError( 'Job [{0}] entered state [{1}] while waiting for [{2}].'.format( job_ref.jobId, state, goal_state)) raise exceptions.JobTimeoutError( 'Job [{0}] timed out while in state [{1}].'.format(job_ref.jobId, state)) # This replicates the fallthrough logic of flags._RegionAttributeConfig. # It is necessary in cases like the --region flag where we are not parsing # ResourceSpecs def ResolveRegion(): return properties.VALUES.dataproc.region.GetOrFail() # This replicates the fallthrough logic of flags._LocationAttributeConfig. # It is necessary in cases like the --location flag where we are not parsing # ResourceSpecs def ResolveLocation(): return properties.VALUES.dataproc.location.GetOrFail() # You probably want to use flags.AddClusterResourceArgument instead. # If calling this method, you *must* have called flags.AddRegionFlag first to # ensure a --region flag is stored into properties, which ResolveRegion # depends on. This is also mutually incompatible with any usage of args.CONCEPTS # which use --region as a resource attribute. def ParseCluster(name, dataproc): ref = dataproc.resources.Parse( name, params={ 'region': ResolveRegion, 'projectId': properties.VALUES.core.project.GetOrFail }, collection='dataproc.projects.regions.clusters') return ref # You probably want to use flags.AddJobResourceArgument instead. # If calling this method, you *must* have called flags.AddRegionFlag first to # ensure a --region flag is stored into properties, which ResolveRegion # depends on. This is also mutually incompatible with any usage of args.CONCEPTS # which use --region as a resource attribute. def ParseJob(job_id, dataproc): ref = dataproc.resources.Parse( job_id, params={ 'region': ResolveRegion, 'projectId': properties.VALUES.core.project.GetOrFail }, collection='dataproc.projects.regions.jobs') return ref def ParseOperationJsonMetadata(metadata_value, metadata_type): """Returns an Operation message for a metadata value.""" if not metadata_value: return metadata_type() return encoding.JsonToMessage(metadata_type, encoding.MessageToJson(metadata_value)) # Used in bizarre scenarios where we want a qualified region rather than a # short name def ParseRegion(dataproc): ref = dataproc.resources.Parse( None, params={ 'regionId': ResolveRegion, 'projectId': properties.VALUES.core.project.GetOrFail }, collection='dataproc.projects.regions') return ref # Get dataproc.projects.locations resource def ParseProjectsLocations(dataproc): ref = dataproc.resources.Parse( None, params={ 'locationsId': ResolveRegion, 'projectsId': properties.VALUES.core.project.GetOrFail }, collection='dataproc.projects.locations') return ref # Get dataproc.projects.locations resource # This can be merged with ParseProjectsLocations() once we have migrated batches # from `region` to `location`. def ParseProjectsLocationsForSession(dataproc): ref = dataproc.resources.Parse( None, params={ 'locationsId': ResolveLocation(), 'projectsId': properties.VALUES.core.project.GetOrFail }, collection='dataproc.projects.locations') return ref def ReadAutoscalingPolicy(dataproc, policy_id, policy_file_name=None): """Returns autoscaling policy read from YAML file. Args: dataproc: wrapper for dataproc resources, client and messages. policy_id: The autoscaling policy id (last piece of the resource name). policy_file_name: if set, location of the YAML file to read from. Otherwise, reads from stdin. Raises: argparse.ArgumentError if duration formats are invalid or out of bounds. """ data = console_io.ReadFromFileOrStdin(policy_file_name or '-', binary=False) policy = export_util.Import( message_type=dataproc.messages.AutoscalingPolicy, stream=data) # Ignore user set id in the file (if any), and overwrite with the policy_ref # provided with this command policy.id = policy_id # Similarly, ignore the set resource name. This field is OUTPUT_ONLY, so we # can just clear it. policy.name = None # Set duration fields to their seconds values if policy.basicAlgorithm is not None: if policy.basicAlgorithm.cooldownPeriod is not None: policy.basicAlgorithm.cooldownPeriod = str( arg_parsers.Duration(lower_bound='2m', upper_bound='1d')( policy.basicAlgorithm.cooldownPeriod)) + 's' if policy.basicAlgorithm.yarnConfig.gracefulDecommissionTimeout is not None: policy.basicAlgorithm.yarnConfig.gracefulDecommissionTimeout = str( arg_parsers.Duration(lower_bound='0s', upper_bound='1d') (policy.basicAlgorithm.yarnConfig.gracefulDecommissionTimeout)) + 's' return policy def CreateAutoscalingPolicy(dataproc, name, policy): """Returns the server-resolved policy after creating the given policy. Args: dataproc: wrapper for dataproc resources, client and messages. name: The autoscaling policy resource name. policy: The AutoscalingPolicy message to create. """ # TODO(b/109837200) make the dataproc discovery doc parameters consistent # Parent() fails for the collection because of projectId/projectsId and # regionId/regionsId inconsistencies. # parent = template_ref.Parent().RelativePath() parent = '/'.join(name.split('/')[0:4]) request = \ dataproc.messages.DataprocProjectsRegionsAutoscalingPoliciesCreateRequest( parent=parent, autoscalingPolicy=policy) policy = dataproc.client.projects_regions_autoscalingPolicies.Create(request) log.status.Print('Created [{0}].'.format(policy.id)) return policy def UpdateAutoscalingPolicy(dataproc, name, policy): """Returns the server-resolved policy after updating the given policy. Args: dataproc: wrapper for dataproc resources, client and messages. name: The autoscaling policy resource name. policy: The AutoscalingPolicy message to create. """ # Though the name field is OUTPUT_ONLY in the API, the Update() method of the # gcloud generated dataproc client expects it to be set. policy.name = name policy = \ dataproc.client.projects_regions_autoscalingPolicies.Update(policy) log.status.Print('Updated [{0}].'.format(policy.id)) return policy def _DownscopeCredentials(token, access_boundary_json): """Downscope the given credentials to the given access boundary. Args: token: The credentials to downscope. access_boundary_json: The JSON-formatted access boundary. Returns: A downscopded credential with the given access-boundary. """ payload = { 'grant_type': 'urn:ietf:params:oauth:grant-type:token-exchange', 'requested_token_type': 'urn:ietf:params:oauth:token-type:access_token', 'subject_token_type': 'urn:ietf:params:oauth:token-type:access_token', 'subject_token': token, 'options': access_boundary_json } universe_domain = properties.VALUES.core.universe_domain.Get() cab_token_url = f'https://sts.{universe_domain}/v1/token' if properties.VALUES.context_aware.use_client_certificate.GetBool(): cab_token_url = f'https://sts.mtls.{universe_domain}/v1/token' headers = {'Content-Type': 'application/x-www-form-urlencoded'} downscope_response = requests.GetSession().post( cab_token_url, headers=headers, data=payload) if downscope_response.status_code != 200: raise ValueError('Error downscoping credentials') cab_token = json.loads(downscope_response.content) return cab_token.get('access_token', None) def GetCredentials(access_boundary_json): """Get an access token for the user's current credentials. Args: access_boundary_json: JSON string holding the definition of the access boundary to apply to the credentials. Raises: PersonalAuthError: If no access token could be fetched for the user. Returns: An access token for the user. """ cred = c_store.Load( None, allow_account_impersonation=True, use_google_auth=True) c_store.Refresh(cred) if c_creds.IsOauth2ClientCredentials(cred): token = cred.access_token else: token = cred.token if not token: raise exceptions.PersonalAuthError( 'No access token could be obtained from the current credentials.') return _DownscopeCredentials(token, access_boundary_json) class PersonalAuthUtils(object): """Util functions for enabling personal auth session.""" def __init__(self): pass def _RunOpensslCommand(self, openssl_executable, args, stdin=None): """Run the specified command, capturing and returning output as appropriate. Args: openssl_executable: The path to the openssl executable. args: The arguments to the openssl command to run. stdin: The input to the command. Returns: The output of the command. Raises: PersonalAuthError: If the call to openssl fails """ command = [openssl_executable] command.extend(args) stderr = None try: if getattr(subprocess, 'run', None): proc = subprocess.run( command, input=stdin, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False) stderr = proc.stderr.decode('utf-8').strip() # N.B. It would be better if we could simply call `subprocess.run` with # the `check` keyword arg set to true rather than manually calling # `check_returncode`. However, we want to capture the stderr when the # command fails, and the CalledProcessError type did not have a field # for the stderr until Python version 3.5. # # As such, we need to manually call `check_returncode` as long as we # are supporting Python versions prior to 3.5. proc.check_returncode() return proc.stdout else: p = subprocess.Popen( command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, _ = p.communicate(input=stdin) return stdout except Exception as ex: if stderr: log.error('OpenSSL command "%s" failed with error message "%s"', ' '.join(command), stderr) raise exceptions.PersonalAuthError('Failure running openssl command: "' + ' '.join(command) + '": ' + six.text_type(ex)) def _ComputeHmac(self, key, data, openssl_executable): """Compute HMAC tag using OpenSSL.""" cmd_output = self._RunOpensslCommand( openssl_executable, ['dgst', '-sha256', '-hmac', key], stdin=data).decode('utf-8') try: # Split the openssl output to get the HMAC. stripped_output = cmd_output.strip().split(' ')[1] if len(stripped_output) != 64: raise ValueError('HMAC output is expected to be 64 characters long.') int(stripped_output, 16) # Check that the HMAC is in hex format. except Exception as ex: raise exceptions.PersonalAuthError( 'Failure due to invalid openssl output: ' + six.text_type(ex)) return (stripped_output + '\n').encode('utf-8') def _DeriveHkdfKey(self, prk, info, openssl_executable): """Derives HMAC-based Key Derivation Function (HKDF) key through expansion on the initial pseudorandom key. Args: prk: a pseudorandom key. info: optional context and application specific information (can be empty). openssl_executable: The path to the openssl executable. Returns: Output keying material, expected to be of 256-bit length. """ if len(prk) != 32: raise ValueError( 'The given initial pseudorandom key is expected to be 32 bytes long.') base16_prk = base64.b16encode(prk).decode('utf-8') t1 = self._ComputeHmac(base16_prk, b'', openssl_executable) t2data = bytearray(t1) t2data.extend(info) t2data.extend(b'\x01') return self._ComputeHmac(base16_prk, t2data, openssl_executable) # It is possible (although very rare) for the random pad generated by # openssl to not be usable by openssl for encrypting the secret. When # that happens the call to openssl will raise a CalledProcessError with # the message "Error reading password from BIO\nError getting password". # # To account for this we retry on that error, but this is so rare that # a single retry should be sufficient. @retry.RetryOnException(max_retrials=1) def _EncodeTokenUsingOpenssl(self, public_key, secret, openssl_executable): """Encode token using OpenSSL. Args: public_key: The public key for the session/cluster. secret: Token to be encrypted. openssl_executable: The path to the openssl executable. Returns: Encrypted token. """ key_hash = hashlib.sha256((public_key + '\n').encode('utf-8')).hexdigest() iv_bytes = base64.b16encode(os.urandom(16)) initialization_vector = iv_bytes.decode('utf-8') initial_key = os.urandom(32) encryption_key = self._DeriveHkdfKey(initial_key, 'encryption_key'.encode('utf-8'), openssl_executable) auth_key = base64.b16encode( self._DeriveHkdfKey(initial_key, 'auth_key'.encode('utf-8'), openssl_executable)).decode('utf-8') with tempfile.NamedTemporaryFile() as kf: kf.write(public_key.encode('utf-8')) kf.seek(0) encrypted_key = self._RunOpensslCommand( openssl_executable, ['rsautl', '-oaep', '-encrypt', '-pubin', '-inkey', kf.name], stdin=base64.b64encode(initial_key)) if len(encrypted_key) != 512: raise ValueError('The encrypted key is expected to be 512 bytes long.') encoded_key = base64.b64encode(encrypted_key).decode('utf-8') with tempfile.NamedTemporaryFile() as pf: pf.write(encryption_key) pf.seek(0) encrypt_args = [ 'enc', '-aes-256-ctr', '-salt', '-iv', initialization_vector, '-pass', 'file:{}'.format(pf.name) ] encrypted_token = self._RunOpensslCommand( openssl_executable, encrypt_args, stdin=secret.encode('utf-8')) if len(encrypted_key) != 512: raise ValueError('The encrypted key is expected to be 512 bytes long.') encoded_token = base64.b64encode(encrypted_token).decode('utf-8') hmac_input = bytearray(iv_bytes) hmac_input.extend(encrypted_token) hmac_tag = self._ComputeHmac(auth_key, hmac_input, openssl_executable).decode('utf-8')[ 0:32] # Truncate the HMAC tag to 128-bit return '{}:{}:{}:{}:{}'.format(key_hash, encoded_token, encoded_key, initialization_vector, hmac_tag) def EncryptWithPublicKey(self, public_key, secret, openssl_executable): """Encrypt secret with resource public key. Args: public_key: The public key for the session/cluster. secret: Token to be encrypted. openssl_executable: The path to the openssl executable. Returns: Encrypted token. """ if openssl_executable: return self._EncodeTokenUsingOpenssl(public_key, secret, openssl_executable) try: # pylint: disable=g-import-not-at-top import tink from tink import hybrid # pylint: enable=g-import-not-at-top except ImportError: raise exceptions.PersonalAuthError( 'Cannot load the Tink cryptography library. Either the ' 'library is not installed, or site packages are not ' 'enabled for the Google Cloud SDK. Please consult Cloud ' 'Dataproc Personal Auth documentation on adding Tink to ' 'Google Cloud SDK for further instructions.\n' 'https://cloud.google.com/dataproc/docs/concepts/iam/personal-auth') hybrid.register() context = b'' # Extract value of key corresponding to primary key. public_key_value = json.loads(public_key)['key'][0]['keyData']['value'] key_hash = hashlib.sha256( (public_key_value + '\n').encode('utf-8')).hexdigest() # Load public key and create keyset handle. reader = tink.JsonKeysetReader(public_key) kh_pub = tink.read_no_secret_keyset_handle(reader) # Create encrypter instance. encrypter = kh_pub.primitive(hybrid.HybridEncrypt) ciphertext = encrypter.encrypt(secret.encode('utf-8'), context) encoded_token = base64.b64encode(ciphertext).decode('utf-8') return '{}:{}'.format(key_hash, encoded_token) def IsTinkLibraryInstalled(self): """Check if Tink cryptography library can be loaded.""" try: # pylint: disable=g-import-not-at-top # pylint: disable=unused-import import tink from tink import hybrid # pylint: enable=g-import-not-at-top # pylint: enable=unused-import return True except ImportError: return False def ReadSessionTemplate(dataproc, template_file_name=None): """Returns session template read from YAML file. Args: dataproc: Wrapper for dataproc resources, client and messages. template_file_name: If set, location of the YAML file to read from. Otherwise, reads from stdin. Raises: argparse.ArgumentError if duration formats are invalid or out of bounds. """ data = console_io.ReadFromFileOrStdin(template_file_name or '-', binary=False) template = export_util.Import( message_type=dataproc.messages.SessionTemplate, stream=data) return template def CreateSessionTemplate(dataproc, name, template): """Returns the server-resolved template after creating the given template. Args: dataproc: Wrapper for dataproc resources, client and messages. name: The session template resource name. template: The SessionTemplate message to create. """ parent = '/'.join(name.split('/')[0:4]) template.name = name request = ( dataproc.messages.DataprocProjectsLocationsSessionTemplatesCreateRequest( parent=parent, sessionTemplate=template)) template = dataproc.client.projects_locations_sessionTemplates.Create(request) log.status.Print('Created [{0}].'.format(template.name)) return template def UpdateSessionTemplate(dataproc, name, template): """Returns the server-resolved template after updating the given template. Args: dataproc: Wrapper for dataproc resources, client and messages. name: The session template resource name. template: The SessionTemplate message to create. """ template.name = name template = dataproc.client.projects_locations_sessionTemplates.Patch(template) log.status.Print('Updated [{0}].'.format(template.name)) return template def YieldFromListWithUnreachableList(unreachable_warning_msg, *args, **kwargs): """Yields from paged List calls handling unreachable list.""" unreachable = set() def _GetFieldFn(message, attr): unreachable.update(message.unreachable) return getattr(message, attr) result = list_pager.YieldFromList(get_field_func=_GetFieldFn, *args, **kwargs) for item in result: yield item if unreachable: log.warning( unreachable_warning_msg, ', '.join(sorted(unreachable)), )