feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for creating an Anywhere Cache instance."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class CreateAnywhereCacheTask(task.Task):
  """Creates an Anywhere Cache instance in particular zone of a bucket."""

  def __init__(self, bucket_url, zone, admission_policy=None, ttl=None):
    """Initializes task.

    Args:
      bucket_url (CloudUrl): The URL of the bucket where the Anywhere Cache
        should be created.
      zone (str): Name of the zonal location where the Anywhere Cache should
        be created.
      admission_policy (str|None): Cache admission policy, deciding on each
        cache miss whether the missed block is inserted.
      ttl (str|None): Cache entry time-to-live in seconds.
    """
    super(CreateAnywhereCacheTask, self).__init__()
    self._bucket_url = bucket_url
    self._zone = zone
    self._admission_policy = admission_policy
    self._ttl = ttl
    # At most one task per (bucket, zone) pair should run in parallel.
    self.parallel_processing_key = '{}/{}'.format(bucket_url.bucket_name, zone)

  def execute(self, task_status_queue=None):
    """Sends the create request and logs the resulting operation id."""
    log.status.Print(
        'Creating a cache instance for bucket {} in zone {}...'.format(
            self._bucket_url, self._zone
        )
    )
    client = api_factory.get_api(self._bucket_url.scheme)
    operation = client.create_anywhere_cache(
        self._bucket_url.bucket_name,
        self._zone,
        admission_policy=self._admission_policy,
        ttl=self._ttl,
    )
    log.status.Print(
        'Initiated the operation id: {} for creating a cache instance for'
        ' bucket {} in zone {}...'.format(
            operation.name, self._bucket_url, self._zone
        )
    )
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    if not isinstance(other, CreateAnywhereCacheTask):
      return NotImplemented
    return (
        self._bucket_url,
        self._zone,
        self._admission_policy,
        self._ttl,
    ) == (
        other._bucket_url,
        other._zone,
        other._admission_policy,
        other._ttl,
    )

View File

@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for disabling an Anywhere Cache instance."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class DisableAnywhereCacheTask(task.Task):
  """Task for disabling an Anywhere Cache instance."""

  def __init__(self, bucket_name, anywhere_cache_id):
    """Initializes task with the target bucket and cache instance id."""
    super(DisableAnywhereCacheTask, self).__init__()
    self._bucket_name = bucket_name
    self._anywhere_cache_id = anywhere_cache_id
    # At most one task per (bucket, cache id) pair should run in parallel.
    self.parallel_processing_key = '{}/{}'.format(
        bucket_name, anywhere_cache_id
    )

  def execute(self, task_status_queue=None):
    """Sends the disable request for the cache instance."""
    log.status.Print(
        'Requesting to disable a cache instance of bucket gs://{} having'
        ' anywhere_cache_id {}'.format(
            self._bucket_name, self._anywhere_cache_id
        )
    )
    client = api_factory.get_api(storage_url.ProviderPrefix.GCS)
    client.disable_anywhere_cache(self._bucket_name, self._anywhere_cache_id)
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    if not isinstance(other, DisableAnywhereCacheTask):
      return NotImplemented
    return (self._bucket_name, self._anywhere_cache_id) == (
        other._bucket_name,
        other._anywhere_cache_id,
    )

View File

@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for updating an Anywhere Cache instance."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class PatchAnywhereCacheTask(task.Task):
  """Updates an Anywhere Cache instance."""

  def __init__(
      self, bucket_name, anywhere_cache_id, admission_policy=None, ttl=None
  ):
    """Initializes task.

    Args:
      bucket_name (str): The name of the bucket where the Anywhere Cache should
        be updated.
      anywhere_cache_id (str): Name of the zonal location where the Anywhere
        Cache should be updated.
      admission_policy (str|None): Cache admission policy, deciding on each
        cache miss whether the missed block is inserted.
      ttl (str|None): Cache entry time-to-live in seconds.
    """
    super(PatchAnywhereCacheTask, self).__init__()
    self._bucket_name = bucket_name
    self._anywhere_cache_id = anywhere_cache_id
    self._admission_policy = admission_policy
    self._ttl = ttl
    # At most one task per (bucket, cache id) pair should run in parallel.
    self.parallel_processing_key = '{}/{}'.format(
        bucket_name, anywhere_cache_id
    )

  def execute(self, task_status_queue=None):
    """Sends the patch request and logs the resulting operation id."""
    log.status.Print(
        'Updating a cache instance of bucket gs://{} having'
        ' anywhere_cache_id {}'.format(
            self._bucket_name, self._anywhere_cache_id
        )
    )
    client = api_factory.get_api(storage_url.ProviderPrefix.GCS)
    operation = client.patch_anywhere_cache(
        self._bucket_name,
        self._anywhere_cache_id,
        admission_policy=self._admission_policy,
        ttl=self._ttl,
    )
    log.status.Print(
        'Initiated the operation id: {} for updating a cache instance of bucket'
        ' gs://{} having anywhere_cache_id {}'.format(
            operation.name, self._bucket_name, self._anywhere_cache_id
        )
    )
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    if not isinstance(other, PatchAnywhereCacheTask):
      return NotImplemented
    return (
        self._bucket_name,
        self._anywhere_cache_id,
        self._admission_policy,
        self._ttl,
    ) == (
        other._bucket_name,
        other._anywhere_cache_id,
        other._admission_policy,
        other._ttl,
    )

View File

@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for pausing an Anywhere Cache instance."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class PauseAnywhereCacheTask(task.Task):
  """Task for pausing an Anywhere Cache instance."""

  def __init__(self, bucket_name, zone):
    """Initializes task with the target bucket and cache zone."""
    super(PauseAnywhereCacheTask, self).__init__()
    self._bucket_name = bucket_name
    # The zone name is used as the Anywhere Cache instance id here.
    self._anywhere_cache_id = zone
    # At most one task per (bucket, zone) pair should run in parallel.
    self.parallel_processing_key = '{}/{}'.format(bucket_name, zone)

  def execute(self, task_status_queue=None):
    """Sends the pause request for the cache instance."""
    log.status.Print(
        'Requesting to pause a cache instance of bucket gs://{} having'
        ' anywhere_cache_id {}'.format(
            self._bucket_name, self._anywhere_cache_id
        )
    )
    client = api_factory.get_api(storage_url.ProviderPrefix.GCS)
    client.pause_anywhere_cache(self._bucket_name, self._anywhere_cache_id)
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    if not isinstance(other, PauseAnywhereCacheTask):
      return NotImplemented
    return (self._bucket_name, self._anywhere_cache_id) == (
        other._bucket_name,
        other._anywhere_cache_id,
    )

View File

@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for resuming an Anywhere Cache instance."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class ResumeAnywhereCacheTask(task.Task):
  """Task for resuming an Anywhere Cache instance."""

  def __init__(self, bucket_name, anywhere_cache_id):
    """Initializes task with the target bucket and cache instance id."""
    super(ResumeAnywhereCacheTask, self).__init__()
    self._bucket_name = bucket_name
    self._anywhere_cache_id = anywhere_cache_id
    # At most one task per (bucket, cache id) pair should run in parallel.
    self.parallel_processing_key = '{}/{}'.format(
        bucket_name, anywhere_cache_id
    )

  def execute(self, task_status_queue=None):
    """Sends the resume request for the cache instance."""
    log.status.Print(
        'Requesting to resume a cache instance of bucket gs://{} having'
        ' anywhere_cache_id: {}'.format(
            self._bucket_name, self._anywhere_cache_id
        )
    )
    client = api_factory.get_api(storage_url.ProviderPrefix.GCS)
    client.resume_anywhere_cache(self._bucket_name, self._anywhere_cache_id)
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    if not isinstance(other, ResumeAnywhereCacheTask):
      return NotImplemented
    return (self._bucket_name, self._anywhere_cache_id) == (
        other._bucket_name,
        other._anywhere_cache_id,
    )

View File

@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for creating a bucket."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class CreateBucketTask(task.Task):
  """Creates a cloud storage bucket."""

  def __init__(self, bucket_resource, user_request_args=None):
    """Initializes task.

    Args:
      bucket_resource (resource_reference.BucketResource): Should contain
        desired metadata for bucket.
      user_request_args (UserRequestArgs|None): Values for request config.
    """
    super(CreateBucketTask, self).__init__()
    self._bucket_resource = bucket_resource
    self._user_request_args = user_request_args

  def execute(self, task_status_queue=None):
    """Issues the bucket creation request."""
    log.status.Print('Creating {}...'.format(self._bucket_resource))
    url = self._bucket_resource.storage_url
    request_config = request_config_factory.get_request_config(
        url, user_request_args=self._user_request_args)
    client = api_factory.get_api(url.scheme)
    client.create_bucket(self._bucket_resource, request_config=request_config)

  def __eq__(self, other):
    if not isinstance(other, CreateBucketTask):
      return NotImplemented
    return (self._bucket_resource, self._user_request_args) == (
        other._bucket_resource,
        other._user_request_args,
    )

View File

@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for deleting a notification configuration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task
class DeleteNotificationConfigurationTask(task.Task):
  """Deletes a notification configuration."""

  def __init__(self, bucket_url, notification_id):
    """Initializes task.

    Args:
      bucket_url (storage_url.CloudUrl): URL of bucket that notification
        configuration exists on.
      notification_id (str): Name of the notification configuration (integer as
        string).
    """
    # Name the class explicitly in super() for consistency with the other
    # task classes in this package (super(__class__, ...) is unidiomatic).
    super(DeleteNotificationConfigurationTask, self).__init__()
    self._bucket_url = bucket_url
    self._notification_id = notification_id
    # Unique key so parallel executors never run two deletes for the same
    # notification configuration at once.
    self.parallel_processing_key = bucket_url.url_string + '|' + notification_id

  def execute(self, task_status_queue=None):
    """Deletes the notification configuration and reports progress."""
    provider = self._bucket_url.scheme
    api_factory.get_api(provider).delete_notification_configuration(
        self._bucket_url, self._notification_id)
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    if not isinstance(other, DeleteNotificationConfigurationTask):
      return NotImplemented
    # parallel_processing_key already encodes both bucket URL and id.
    return self.parallel_processing_key == other.parallel_processing_key

View File

@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for restoring a soft-deleted bucket."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class RestoreBucketTask(task.Task):
  """Restores a soft-deleted cloud storage bucket."""

  def __init__(self, bucket_url):
    """Initializes task.

    Args:
      bucket_url (CloudUrl): Bucket URL to restore.
    """
    super(RestoreBucketTask, self).__init__()
    self._bucket_url = bucket_url

  def execute(self, task_status_queue=None):
    """Issues the bucket restore request and reports progress."""
    log.status.Print('Restoring {}...'.format(self._bucket_url.url_string))
    provider = self._bucket_url.scheme
    api_factory.get_api(provider).restore_bucket(self._bucket_url)
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    # Check against the explicit class (matching the sibling task classes,
    # which all use isinstance(other, ClassName)) rather than type(self).
    if not isinstance(other, RestoreBucketTask):
      return NotImplemented
    return self._bucket_url == other._bucket_url

View File

@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for updating a bucket."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import errors
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.artifacts import requests
from googlecloudsdk.command_lib.storage import errors as command_errors
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
from googlecloudsdk.core.console import console_io
class UpdateBucketTask(task.Task):
  """Updates a cloud storage bucket's metadata."""

  def __init__(self, bucket_resource, user_request_args=None):
    """Initializes task.

    Args:
      bucket_resource (BucketResource|UnknownResource): The bucket to update.
      user_request_args (UserRequestArgs|None): Describes metadata updates to
        perform.
    """
    super(UpdateBucketTask, self).__init__()
    self._bucket_resource = bucket_resource
    self._user_request_args = user_request_args

  def __eq__(self, other):
    if not isinstance(other, UpdateBucketTask):
      return NotImplemented
    return (self._bucket_resource == other._bucket_resource and
            self._user_request_args == other._user_request_args)

  def _confirm_and_lock_retention_policy(self, api_client, bucket_resource,
                                         request_config):
    """Locks a buckets retention policy if possible and the user confirms.

    Args:
      api_client (cloud_api.CloudApi): API client that should issue the lock
        request.
      bucket_resource (BucketResource): Metadata of the bucket containing the
        retention policy to lock.
      request_config (request_config_factory._RequestConfig): Contains
        additional request parameters.
    """
    # Built up front so the prompt text is ready if we reach the confirm
    # branch below.
    lock_prompt = (
        'This will permanently set the retention policy on "{}" to the'
        ' following:\n\n{}\n\nThis setting cannot be reverted. Continue? '
    ).format(self._bucket_resource, bucket_resource.retention_policy)
    if not bucket_resource.retention_policy:
      # No policy to lock: hard error.
      raise command_errors.Error(
          'Bucket "{}" does not have a retention policy.'.format(
              self._bucket_resource))
    elif bucket_resource.retention_policy_is_locked:
      # Already locked: log an error but do not raise.
      log.error('Retention policy on "{}" is already locked.'.format(
          self._bucket_resource))
    elif console_io.PromptContinue(message=lock_prompt, default=False):
      # User confirmed: issue the irreversible lock request.
      log.status.Print('Locking retention policy on {}...'.format(
          self._bucket_resource))
      api_client.lock_bucket_retention_policy(bucket_resource, request_config)
    else:
      # Gsutil does not update the exit code here, so we cannot use
      # cancel_or_no with PromptContinue.
      log.error('Abort locking retention policy on "{}".'.format(
          self._bucket_resource))

  def execute(self, task_status_queue=None):
    """Patches the bucket, retrying once after granting KMS permission on 403."""
    log.status.Print('Updating {}...'.format(self._bucket_resource))
    request_config = request_config_factory.get_request_config(
        self._bucket_resource.storage_url,
        user_request_args=self._user_request_args)
    provider = self._bucket_resource.storage_url.scheme
    api_client = api_factory.get_api(provider)
    try:
      bucket_metadata = api_client.patch_bucket(
          self._bucket_resource, request_config=request_config)
    except errors.GcsApiError as e:
      # Service agent does not have the encrypter/decrypter role.
      if (e.payload.status_code == 403 and
          request_config.resource_args.default_encryption_key):
        # Grant the service agent access to the key, then retry the patch
        # once.
        service_agent = api_client.get_service_agent()
        requests.AddCryptoKeyPermission(
            request_config.resource_args.default_encryption_key,
            'serviceAccount:' + service_agent)
        bucket_metadata = api_client.patch_bucket(
            self._bucket_resource, request_config=request_config)
      else:
        raise
    # Lock the retention policy only when the request explicitly asked for it.
    if getattr(
        request_config.resource_args, 'retention_period_to_be_locked', None):
      self._confirm_and_lock_retention_policy(
          api_client, bucket_metadata, request_config)
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

View File

@@ -0,0 +1,79 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of CatTaskIterator for calling the StreamingDownloadTask."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import os
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage.resources import resource_reference
from googlecloudsdk.command_lib.storage.tasks.cp import streaming_download_task
def _get_start_byte(start_byte, source_resource_size):
"""Returns the byte index to start streaming from.
Gets an absolute start byte for object download API calls.
Args:
start_byte (int): The start index entered by the user. Negative values are
interpreted as offsets from the end of the object.
source_resource_size (int|None): The size of the source resource.
Returns:
int: The byte index to start the object download from.
"""
if start_byte < 0:
if abs(start_byte) >= source_resource_size:
return 0
return source_resource_size - abs(start_byte)
return start_byte
def get_cat_task_iterator(source_iterator, show_url, start_byte, end_byte):
  """An iterator that yields StreamingDownloadTasks for cat sources.

  Given a list of strings that are object URLs ("gs://foo/object1"), yield a
  StreamingDownloadTask.

  Args:
    source_iterator (NameExpansionIterator): Yields sources resources that
      should be packaged in StreamingDownloadTasks.
    show_url (bool): Says whether or not to print the header before each
      object's content.
    start_byte (int): The byte index to start streaming from.
    end_byte (int|None): The byte index to stop streaming from.

  Yields:
    StreamingDownloadTask
  """
  # Binary wrapper around stdout (fd 1) so object bytes pass through raw.
  out_stream = os.fdopen(1, 'wb')
  # Placeholder destination; '-' conventionally denotes a stream target.
  stream_destination = resource_reference.FileObjectResource(
      storage_url.FileUrl('-')
  )
  for source in source_iterator:
    resource = source.resource
    yield streaming_download_task.StreamingDownloadTask(
        resource,
        stream_destination,
        download_stream=out_stream,
        show_url=show_url,
        start_byte=_get_start_byte(start_byte, resource.size),
        end_byte=end_byte,
    )

View File

@@ -0,0 +1,108 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for composing storage objects."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import cloud_api
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import errors as command_errors
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class ComposeObjectsTask(task.Task):
  """Composes storage objects."""

  def __init__(
      self,
      source_resources,
      destination_resource,
      original_source_resource=None,
      posix_to_set=None,
      print_status_message=False,
      user_request_args=None,
  ):
    """Initializes task.

    Args:
      source_resources (list[ObjectResource|UnknownResource]): The objects to
        compose. UnknownResources are accepted so ComposeObjectsTasks can be
        initialized before the target objects have been created.
      destination_resource (resource_reference.UnknownResource): Metadata for
        the resulting composite object.
      original_source_resource (Resource|None): Useful for finding metadata to
        apply to the final object; for a composite upload this would represent
        the pre-split local file.
      posix_to_set (PosixAttributes|None): POSIX info set as custom cloud
        metadata on the target. If preserving POSIX, avoids re-parsing metadata
        from the file system.
      print_status_message (bool): If True, the task prints the status message.
      user_request_args (UserRequestArgs|None): Values for RequestConfig.
    """
    super(ComposeObjectsTask, self).__init__()
    self._source_resources = source_resources
    self._destination_resource = destination_resource
    self._original_source_resource = original_source_resource
    self._posix_to_set = posix_to_set
    self._print_status_message = print_status_message
    self._user_request_args = user_request_args

  def execute(self, task_status_queue=None):
    """Composes the source objects into the destination object."""
    del task_status_queue  # Unused.
    destination_url = self._destination_resource.storage_url
    provider = destination_url.scheme
    api = api_factory.get_api(provider)
    if cloud_api.Capability.COMPOSE_OBJECTS not in api.capabilities:
      raise command_errors.Error(
          'Compose is not available with requested provider: {}'.format(
              provider))
    # Every source must live in the destination's bucket.
    for source in self._source_resources:
      if source.storage_url.bucket_name != destination_url.bucket_name:
        raise command_errors.Error('Inter-bucket composing not supported')
    request_config = request_config_factory.get_request_config(
        destination_url, user_request_args=self._user_request_args)
    if self._print_status_message:
      log.status.write('Composing {} from {} component object(s).\n'.format(
          self._destination_resource, len(self._source_resources)))
    created_resource = api.compose_objects(
        self._source_resources,
        self._destination_resource,
        request_config,
        original_source_resource=self._original_source_resource,
        posix_to_set=self._posix_to_set,
    )
    return task.Output(
        messages=[
            task.Message(
                topic=task.Topic.CREATED_RESOURCE, payload=created_resource),
        ],
        additional_task_iterators=[])

  def __eq__(self, other):
    if not isinstance(other, ComposeObjectsTask):
      return NotImplemented
    return (self._source_resources == other._source_resources and
            self._destination_resource == other._destination_resource and
            self._user_request_args == other._user_request_args)

View File

@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for components in copy operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import hashlib
import math
import os
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage.resources import resource_reference
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import files
from googlecloudsdk.core.util import scaled_integer
# Salt prepended to source object names before hashing when generating
# temporary component names, so collisions with real user object names are
# effectively impossible. Runtime value; do not edit the text.
_PARALLEL_UPLOAD_STATIC_PREFIX = """
PARALLEL_UPLOAD_SALT_TO_PREVENT_COLLISIONS.
The theory is that no user will have prepended this to the front of
one of their object names and then do an MD5 hash of the name, and
then prepended PARALLEL_UPLOAD_TEMP_NAMESPACE to the front of their object
name. Note that there will be no problems with object name length since we
hash the original name.
"""
# NOTE(review): not referenced in this chunk; presumably the object-name
# prefix under which temporary composite-upload components live — confirm
# against callers.
_PARALLEL_UPLOAD_TEMPORARY_NAMESPACE = (
    'gcloud/tmp/parallel_composite_uploads/'
    'see_gcloud_storage_cp_help_for_details/')
def _ensure_truthy_path_ends_with_single_delimiter(string, delimiter):
if not string:
return ''
return string.rstrip(delimiter) + delimiter
def _get_temporary_component_name(
    source_resource, destination_resource, random_prefix, component_id
):
  """Gets a temporary object name for a component of source_resource."""
  # Hash the salted source name so component names are both unique per source
  # and bounded in length.
  salted_name = (
      _PARALLEL_UPLOAD_STATIC_PREFIX
      + source_resource.storage_url.resource_name
  )
  name_digest = hashlib.sha1(salted_name.encode('utf-8')).hexdigest()
  component_prefix = (
      properties.VALUES.storage.parallel_composite_upload_component_prefix.Get()
  )
  delimiter = destination_resource.storage_url.delimiter
  if component_prefix.startswith(delimiter):
    # A leading delimiter means the prefix is relative to the bucket root.
    prefix = component_prefix.lstrip(delimiter)
  else:
    # Otherwise the prefix is relative to the destination object's parent.
    destination_object_name = destination_resource.storage_url.resource_name
    parent_prefix, _, _ = destination_object_name.rpartition(delimiter)
    prefix = (
        _ensure_truthy_path_ends_with_single_delimiter(parent_prefix, delimiter)
        + component_prefix
    )
  return '{}{}_{}_{}'.format(
      _ensure_truthy_path_ends_with_single_delimiter(prefix, delimiter),
      random_prefix,
      name_digest,
      str(component_id),
  )
def create_file_if_needed(source_resource, destination_resource):
  """Creates new file if none exists or one that is too large exists at path.

  Args:
    source_resource (ObjectResource): Contains size metadata for target file.
    destination_resource(FileObjectResource|UnknownResource): Contains path to
      create file at.
  """
  file_path = destination_resource.storage_url.resource_name
  # A file no larger than the source can be reused (e.g. components written
  # at offsets will overwrite it in place); a larger file would be left with
  # stale trailing bytes, so it is truncated below.
  if os.path.exists(
      file_path) and os.path.getsize(file_path) <= source_resource.size:
    return
  with files.BinaryFileWriter(
      file_path,
      create_path=True,
      mode=files.BinaryFileWriterMode.TRUNCATE,
      convert_invalid_windows_characters=properties.VALUES.storage
      .convert_incompatible_windows_path_characters.GetBool()):
    # Wipe or create file.
    pass
def get_temporary_component_resource(source_resource, destination_resource,
                                     random_prefix, component_id):
  """Builds the destination resource for one composite-upload component.

  Args:
    source_resource (resource_reference.FileObjectResource): The upload source.
    destination_resource (resource_reference.ObjectResource|UnknownResource):
      The upload destination.
    random_prefix (str): Added to temporary component names to avoid collisions
      between different instances of the CLI uploading to the same destination.
    component_id (int): An id that's not shared by any other component in this
      transfer.

  Returns:
    resource_reference.UnknownResource pointing at where the component should
    be written.
  """
  temporary_object_name = _get_temporary_component_name(
      source_resource, destination_resource, random_prefix, component_id
  )
  final_destination_url = destination_resource.storage_url
  # Components live in the same scheme and bucket as the final destination.
  temporary_component_url = storage_url.CloudUrl(
      final_destination_url.scheme,
      final_destination_url.bucket_name,
      temporary_object_name,
  )
  return resource_reference.UnknownResource(temporary_component_url)
def get_component_count(file_size, target_component_size, max_components):
  """Returns the # components a file would be split into for a composite upload.

  Args:
    file_size (int|None): Total byte size of file being divided into
      components. None if it could not be determined.
    target_component_size (int|str): Target size for each component if the
      total isn't capped by max_components. May be a byte count int or a size
      string (e.g. "50M").
    max_components (int|None): Limit on allowed components regardless of
      file_size and target_component_size. None indicates no limit.

  Returns:
    int: Number of components to split file into for composite upload.
  """
  if file_size is None:
    # Unknown size: fall back to a single-component upload.
    return 1

  if isinstance(target_component_size, int):
    component_size_bytes = target_component_size
  else:
    # Size strings like "50M" are parsed into byte counts.
    component_size_bytes = scaled_integer.ParseInteger(target_component_size)

  size_based_count = math.ceil(file_size / component_size_bytes)
  # Non-positive or absent max_components means "uncapped".
  if max_components is not None and max_components >= 1:
    return min(size_based_count, max_components)
  return size_based_count
def get_component_offsets_and_lengths(file_size, component_count):
  """Calculates start bytes and sizes for a multi-component copy operation.

  Args:
    file_size (int): Total byte size of file being divided into components.
    component_count (int): Number of components to divide file into.

  Returns:
    List of component offsets and lengths: list[(offset, length)].
    Total component count can be found by taking the length of the list.
    May contain fewer than component_count entries when the rounded-up
    component size covers the file in fewer pieces.
  """
  length_per_component = math.ceil(file_size / component_count)
  component_offsets_and_lengths = []
  offset = 0
  # Stepping by the rounded-up component size reaches file_size in at most
  # component_count iterations, so no explicit cap is needed.
  while offset < file_size:
    component_offsets_and_lengths.append(
        (offset, min(length_per_component, file_size - offset))
    )
    offset += length_per_component
  return component_offsets_and_lengths

View File

@@ -0,0 +1,195 @@
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for copying a folder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import io
import os
import threading
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import errors as api_errors
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task_status
from googlecloudsdk.command_lib.storage.tasks.cp import copy_util
class RenameFolderTask(copy_util.CopyTaskWithExitHandler):
  """Represents a command operation renaming a folder around the cloud."""

  def __init__(
      self,
      source_resource,
      destination_resource,
      print_created_message=False,
      user_request_args=None,
      verbose=False,
  ):
    """Initializes RenameFolderTask. Parent class documents arguments."""
    super(RenameFolderTask, self).__init__(
        source_resource=source_resource,
        destination_resource=destination_resource,
        print_created_message=print_created_message,
        user_request_args=user_request_args,
        verbose=verbose,
    )
    # Key tasks on the destination URL — presumably so the task executor can
    # serialize work targeting the same destination; confirm with executor.
    self.parallel_processing_key = (
        self._destination_resource.storage_url.url_string
    )

  def execute(self, task_status_queue=None):
    """Renames the source folder to the destination and waits for completion.

    Args:
      task_status_queue (multiprocessing.Queue|None): If provided, receives a
        zero-byte progress update after the rename finishes.
    """
    source_url = self._source_resource.storage_url
    destination_url = self._destination_resource.storage_url
    api_client = api_factory.get_api(source_url.scheme)
    if task_status_queue is not None:
      # A folder rename moves no object bytes, hence offset=0 and length=0.
      progress_callback = progress_callbacks.FilesAndBytesProgressCallback(
          status_queue=task_status_queue,
          offset=0,
          length=0,
          source_url=self._source_resource.storage_url,
          destination_url=self._destination_resource.storage_url,
          operation_name=task_status.OperationName.INTRA_CLOUD_COPYING,
          process_id=os.getpid(),
          thread_id=threading.get_ident(),
      )
    else:
      progress_callback = None
    operation = api_client.rename_folder(
        destination_url.bucket_name,
        source_url.resource_name,
        destination_url.resource_name,
    )
    # rename_folder returns a long-running operation; block until it is done
    # so success is only reported once the rename has completed.
    if not operation.done:
      api_client.wait_for_operation(operation)
    self._print_created_message_if_requested(self._destination_resource)
    if progress_callback:
      progress_callback(0)

  def __eq__(self, other):
    # Value equality over every constructor input that affects behavior.
    if not isinstance(other, RenameFolderTask):
      return NotImplemented
    return (
        self._source_resource == other._source_resource
        and self._destination_resource == other._destination_resource
        and self._print_created_message == other._print_created_message
        and self._user_request_args == other._user_request_args
        and self._verbose == other._verbose
    )
class CopyFolderTask(copy_util.CopyTaskWithExitHandler):
  """Represents a command operation copying a folder around the cloud."""

  def __init__(
      self,
      source_resource,
      destination_resource,
      print_created_message=False,
      user_request_args=None,
      verbose=False,
  ):
    """Initializes CopyFolderTask. Parent class documents arguments."""
    super(CopyFolderTask, self).__init__(
        source_resource=source_resource,
        destination_resource=destination_resource,
        print_created_message=print_created_message,
        user_request_args=user_request_args,
        verbose=verbose,
    )
    # Key tasks on the destination URL — presumably so the task executor can
    # serialize work targeting the same destination; confirm with executor.
    self.parallel_processing_key = (
        self._destination_resource.storage_url.url_string
    )

  def execute(self, task_status_queue=None):
    """Creates the destination folder, HNS-aware.

    Args:
      task_status_queue (multiprocessing.Queue|None): If provided, receives a
        zero-byte progress update after the folder is created.
    """
    source_url = self._source_resource.storage_url
    destination_url = self._destination_resource.storage_url
    api_client = api_factory.get_api(source_url.scheme)
    if task_status_queue is not None:
      # Folder creation moves no object bytes, hence offset=0 and length=0.
      progress_callback = progress_callbacks.FilesAndBytesProgressCallback(
          status_queue=task_status_queue,
          offset=0,
          length=0,
          source_url=self._source_resource.storage_url,
          destination_url=self._destination_resource.storage_url,
          operation_name=task_status.OperationName.INTRA_CLOUD_COPYING,
          process_id=os.getpid(),
          thread_id=threading.get_ident(),
      )
    else:
      progress_callback = None
    bucket_layout = api_client.get_storage_layout(destination_url.bucket_name)
    # GetStorageLayout requires ListObjects permission to work.
    # While for most cases, (especially in this code path) the user would
    # have the permission, we do not want to absorb the error as this is an
    # entirely new workflow (HNS buckets) and absorbing this would end up
    # invoking upload_objects which would create objects instead of folders
    # in an HNS bucket.
    if (
        bucket_layout
        and getattr(bucket_layout, 'hierarchicalNamespace', None)
        and bucket_layout.hierarchicalNamespace.enabled
    ):
      # We are copying to an HNS bucket. This means we can and should use
      # create_folders API.
      try:
        api_client.create_folder(
            destination_url.bucket_name,
            destination_url.resource_name,
            is_recursive=True,
        )
      except api_errors.ConflictError:
        # If the folder already exists, we can just skip this step.
        pass
    else:
      # We are copying to a flat namespace bucket. This means we need to
      # upload an empty object to create the folder.
      # NOTE(review): this uploads an empty text stream (io.StringIO) —
      # confirm the API client accepts a str stream here rather than bytes.
      request_config = request_config_factory.get_request_config(
          destination_url,
          content_type=request_config_factory.DEFAULT_CONTENT_TYPE,
          size=None,
      )
      api_client.upload_object(
          io.StringIO(''),
          self._destination_resource,
          request_config=request_config,
      )
    self._print_created_message_if_requested(self._destination_resource)
    if progress_callback:
      progress_callback(0)

  def __eq__(self, other):
    # Value equality over every constructor input that affects behavior.
    if not isinstance(other, CopyFolderTask):
      return NotImplemented
    return (
        self._source_resource == other._source_resource
        and self._destination_resource == other._destination_resource
        and self._print_created_message == other._print_created_message
        and self._user_request_args == other._user_request_args
        and self._verbose == other._verbose
    )

View File

@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for copying a managed folder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import os
import threading
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import errors
from googlecloudsdk.api_lib.storage import gcs_iam_util
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task_status
from googlecloudsdk.command_lib.storage.tasks.cp import copy_util
class CopyManagedFolderTask(copy_util.CopyTaskWithExitHandler):
  """Represents a command operation copying a managed folder around the cloud."""

  def __init__(
      self,
      source_resource,
      destination_resource,
      print_created_message=False,
      user_request_args=None,
      verbose=False,
  ):
    """Initializes CopyManagedFolderTask. Parent class documents arguments."""
    super(CopyManagedFolderTask, self).__init__(
        source_resource=source_resource,
        destination_resource=destination_resource,
        print_created_message=print_created_message,
        user_request_args=user_request_args,
        verbose=verbose,
    )
    # Key tasks on the destination URL — presumably so the task executor can
    # serialize work targeting the same destination; confirm with executor.
    self.parallel_processing_key = (
        self._destination_resource.storage_url.url_string
    )

  def execute(self, task_status_queue=None):
    """Creates the destination managed folder and copies the source IAM policy.

    Args:
      task_status_queue (multiprocessing.Queue|None): If provided, receives a
        zero-byte progress update after the copy finishes.
    """
    source_url = self._source_resource.storage_url
    destination_url = self._destination_resource.storage_url
    api_client = api_factory.get_api(source_url.scheme)
    if task_status_queue is not None:
      # Managed-folder copies move no object bytes, hence offset=0, length=0.
      progress_callback = progress_callbacks.FilesAndBytesProgressCallback(
          status_queue=task_status_queue,
          offset=0,
          length=0,
          source_url=self._source_resource.storage_url,
          destination_url=self._destination_resource.storage_url,
          operation_name=task_status.OperationName.INTRA_CLOUD_COPYING,
          process_id=os.getpid(),
          thread_id=threading.get_ident(),
      )
    else:
      progress_callback = None
    # Fetch the source policy before creating the destination so the copy
    # fails early if the policy is unreadable.
    source_policy = api_client.get_managed_folder_iam_policy(
        source_url.bucket_name, source_url.resource_name
    )
    try:
      api_client.create_managed_folder(
          destination_url.bucket_name,
          destination_url.resource_name,
      )
    except errors.ConflictError:
      # Destination managed folder already exists; reuse it.
      pass
    self._print_created_message_if_requested(self._destination_resource)
    # Source etag will not match the destination causing precondition failures.
    source_policy.etag = None
    # Version must be specified.
    source_policy.version = gcs_iam_util.IAM_POLICY_VERSION
    api_client.set_managed_folder_iam_policy(
        destination_url.bucket_name,
        destination_url.resource_name,
        source_policy,
    )
    if progress_callback:
      progress_callback(0)

  def __eq__(self, other):
    # Value equality over every constructor input that affects behavior.
    if not isinstance(other, CopyManagedFolderTask):
      return NotImplemented
    return (
        self._source_resource == other._source_resource
        and self._destination_resource == other._destination_resource
        and self._print_created_message == other._print_created_message
        and self._user_request_args == other._user_request_args
        and self._verbose == other._verbose
    )

View File

@@ -0,0 +1,220 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preferred method of generating a copy task."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import posix_util
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage.resources import resource_reference
from googlecloudsdk.command_lib.storage.tasks.cp import copy_folder_task
from googlecloudsdk.command_lib.storage.tasks.cp import copy_managed_folder_task
from googlecloudsdk.command_lib.storage.tasks.cp import daisy_chain_copy_task
from googlecloudsdk.command_lib.storage.tasks.cp import file_download_task
from googlecloudsdk.command_lib.storage.tasks.cp import file_upload_task
from googlecloudsdk.command_lib.storage.tasks.cp import intra_cloud_copy_task
from googlecloudsdk.command_lib.storage.tasks.cp import parallel_composite_upload_util
from googlecloudsdk.command_lib.storage.tasks.cp import streaming_download_task
from googlecloudsdk.command_lib.storage.tasks.cp import streaming_upload_task
def get_copy_task(
    source_resource,
    destination_resource,
    delete_source=False,
    do_not_decompress=False,
    fetch_source_fields_scope=None,
    force_daisy_chain=False,
    posix_to_set=None,
    print_created_message=False,
    print_source_version=False,
    shared_stream=None,
    user_request_args=None,
    verbose=False,
):
  """Factory method that returns the correct copy task for the arguments.

  Args:
    source_resource (resource_reference.Resource): Reference to file to copy.
    destination_resource (resource_reference.Resource): Reference to
      destination to copy file to.
    delete_source (bool): If copy completes successfully, delete the source
      object afterwards.
    do_not_decompress (bool): Prevents automatically decompressing downloaded
      gzips.
    fetch_source_fields_scope (FieldsScope|None): If present, refetch
      source_resource, populated with metadata determined by this FieldsScope.
      Useful for lazy or parallelized GET calls. Currently only implemented
      for intra-cloud copies and daisy chain copies.
    force_daisy_chain (bool): If True, yields daisy chain copy tasks in place
      of intra-cloud copy tasks.
    posix_to_set (PosixAttributes|None): Triggers setting POSIX on result of
      copy and avoids re-parsing POSIX info.
    print_created_message (bool): Print the versioned URL of each successfully
      copied object.
    print_source_version (bool): Print source object version in status message
      enabled by the `verbose` kwarg.
    shared_stream (stream): Multiple tasks may reuse this read or write
      stream.
    user_request_args (UserRequestArgs|None): Values for RequestConfig.
    verbose (bool): Print a "copying" status message on task initialization.

  Returns:
    Task object that can be executed to perform a copy.

  Raises:
    NotImplementedError: Cross-cloud copy.
    Error: Local filesystem copy.
  """
  source_url = source_resource.storage_url
  destination_url = destination_resource.storage_url

  # Local-to-local copies are the OS's job; this surface requires at least
  # one cloud endpoint.
  if (isinstance(source_url, storage_url.FileUrl)
      and isinstance(destination_url, storage_url.FileUrl)):
    raise errors.Error(
        'Local copies not supported. Gcloud command-line tool is'
        ' meant for cloud operations. Received copy from {} to {}'.format(
            source_url, destination_url
        )
    )

  # Cloud -> local: download (streaming when the target is a pipe).
  if (isinstance(source_url, storage_url.CloudUrl)
      and isinstance(destination_url, storage_url.FileUrl)):
    if destination_url.is_stream:
      return streaming_download_task.StreamingDownloadTask(
          source_resource,
          destination_resource,
          shared_stream,
          print_created_message=print_created_message,
          print_source_version=print_source_version,
          user_request_args=user_request_args,
          verbose=verbose,
      )
    return file_download_task.FileDownloadTask(
        source_resource,
        destination_resource,
        delete_source=delete_source,
        do_not_decompress=do_not_decompress,
        posix_to_set=posix_to_set,
        print_created_message=print_created_message,
        print_source_version=print_source_version,
        system_posix_data=posix_util.run_if_setting_posix(
            posix_to_set, user_request_args, posix_util.get_system_posix_data
        ),
        user_request_args=user_request_args,
        verbose=verbose,
    )

  # Local -> cloud: upload (streaming when the source is a pipe).
  if (isinstance(source_url, storage_url.FileUrl)
      and isinstance(destination_url, storage_url.CloudUrl)):
    if source_url.is_stream:
      return streaming_upload_task.StreamingUploadTask(
          source_resource,
          destination_resource,
          posix_to_set=posix_to_set,
          print_created_message=print_created_message,
          print_source_version=print_source_version,
          user_request_args=user_request_args,
          verbose=verbose,
      )
    else:
      is_composite_upload_eligible = (
          parallel_composite_upload_util.is_composite_upload_eligible(
              source_resource, destination_resource, user_request_args))
      return file_upload_task.FileUploadTask(
          source_resource,
          destination_resource,
          delete_source=delete_source,
          is_composite_upload_eligible=is_composite_upload_eligible,
          posix_to_set=posix_to_set,
          print_created_message=print_created_message,
          print_source_version=print_source_version,
          user_request_args=user_request_args,
          verbose=verbose,
      )

  # Cloud -> cloud: folder, managed-folder, daisy-chain, or intra-cloud copy.
  if (isinstance(source_url, storage_url.CloudUrl)
      and isinstance(destination_url, storage_url.CloudUrl)):
    different_providers = source_url.scheme != destination_url.scheme
    if (different_providers and user_request_args and
        user_request_args.resource_args and
        user_request_args.resource_args.preserve_acl):
      raise errors.Error(
          'Cannot preserve ACLs while copying between cloud providers.'
      )
    # If the source_resource is a folder and other conditions for
    # rename_folders are met, we need not invoke the CopyManagedFolderTask
    # as the CopyFolderTask would take care of it automatically.
    is_folders_use_case = (
        isinstance(source_resource, resource_reference.FolderResource)
        and not different_providers
    )
    # Same-bucket folder move with delete_source is a rename, which the API
    # supports natively.
    if (
        is_folders_use_case
        and delete_source
        and not force_daisy_chain
        and source_resource.bucket
        == destination_resource.storage_url.bucket_name
    ):
      return copy_folder_task.RenameFolderTask(
          source_resource,
          destination_resource,
          print_created_message=print_created_message,
          user_request_args=user_request_args,
          verbose=verbose,
      )
    elif is_folders_use_case:
      return copy_folder_task.CopyFolderTask(
          source_resource,
          destination_resource,
          print_created_message=print_created_message,
          user_request_args=user_request_args,
          verbose=verbose,
      )
    elif isinstance(source_resource, resource_reference.ManagedFolderResource):
      return copy_managed_folder_task.CopyManagedFolderTask(
          source_resource,
          destination_resource,
          print_created_message=print_created_message,
          user_request_args=user_request_args,
          verbose=verbose,
      )
    if different_providers or force_daisy_chain:
      return daisy_chain_copy_task.DaisyChainCopyTask(
          source_resource,
          destination_resource,
          delete_source=delete_source,
          posix_to_set=posix_to_set,
          print_created_message=print_created_message,
          print_source_version=print_source_version,
          user_request_args=user_request_args,
          verbose=verbose,
          fetch_source_fields_scope=fetch_source_fields_scope,
      )
    return intra_cloud_copy_task.IntraCloudCopyTask(
        source_resource,
        destination_resource,
        delete_source=delete_source,
        fetch_source_fields_scope=fetch_source_fields_scope,
        posix_to_set=posix_to_set,
        print_created_message=print_created_message,
        print_source_version=print_source_version,
        user_request_args=user_request_args,
        verbose=verbose,
    )

View File

@@ -0,0 +1,837 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task iterator for copy functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import os
from googlecloudsdk.api_lib.storage import cloud_api
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import folder_util
from googlecloudsdk.command_lib.storage import manifest_util
from googlecloudsdk.command_lib.storage import path_util
from googlecloudsdk.command_lib.storage import plurality_checkable_iterator
from googlecloudsdk.command_lib.storage import posix_util
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage import wildcard_iterator
from googlecloudsdk.command_lib.storage.resources import gcs_resource_reference
from googlecloudsdk.command_lib.storage.resources import resource_reference
from googlecloudsdk.command_lib.storage.resources import resource_util
from googlecloudsdk.command_lib.storage.tasks.cp import copy_task_factory
from googlecloudsdk.command_lib.storage.tasks.cp import copy_util
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
# 2**40 bytes. Not referenced in this chunk — presumably a threshold used
# further down in this file; confirm at use site.
_ONE_TB_IN_BYTES = 1099511627776
# Path components that refer to the current or parent directory; used to
# reject FileUrls whose parent directory is a relative path symbol.
_RELATIVE_PATH_SYMBOLS = frozenset(['.', '..'])
def _expand_destination_wildcards(destination_string, folders_only=False):
  """Expands destination wildcards.

  Ensures that only one resource matches the wildcard expanded string. Much
  like the unix cp command, the storage surface only supports copy operations
  to one user-specified destination.

  Args:
    destination_string (str): A string representing the destination url.
    folders_only (bool): If True, indicates that we are invoking folders only
      copy task.

  Returns:
    A resource_reference.Resource, or None if no matching resource is found
    (and the string contains no unescaped wildcard).

  Raises:
    InvalidUrlError if more than one resource is matched, or the source
      contained an unescaped wildcard and no resources were matched.
  """
  destination_iterator = (
      plurality_checkable_iterator.PluralityCheckableIterator(
          wildcard_iterator.get_wildcard_iterator(
              destination_string,
              # Folders-only mode must surface folder resources; otherwise
              # folders are excluded from listing.
              folder_setting=folder_util.FolderSetting.LIST_AS_FOLDERS
              if folders_only
              else folder_util.FolderSetting.DO_NOT_LIST,
              fields_scope=cloud_api.FieldsScope.SHORT,
          )
      )
  )
  if destination_iterator.is_plural():
    # If the result is plural, we are bound to throw an error.
    # But we also should check if this is a case of duplicate results due to a
    # placeholder folder which was created through the UI.
    # If it is not the case, we continue with raising the Error and not moving
    # further with the method.
    # If it is the case of duplicates, we do not raise the Error, and rather
    # continue with method execution as planned.
    resolved_resource = _resolve_duplicate_ui_folder_destination(
        destination_string, destination_iterator
    )
    if not resolved_resource:
      raise errors.InvalidUrlError(
          f'Destination ({destination_string}) must match exactly one URL.'
      )
    # Replace the plural iterator with a singular one holding the resolved
    # resource so the logic below behaves as if one match was found.
    destination_iterator = (
        plurality_checkable_iterator.PluralityCheckableIterator(
            [resolved_resource]
        )
    )
  contains_unexpanded_wildcard = (
      destination_iterator.is_empty()
      and wildcard_iterator.contains_wildcard(destination_string)
  )
  if contains_unexpanded_wildcard:
    raise errors.InvalidUrlError(
        f'Destination ({destination_string}) contains an unexpected wildcard.'
    )
  if not destination_iterator.is_empty():
    return next(destination_iterator)
  # Implicitly returns None: no match and no wildcard (destination may not
  # exist yet).
def _resolve_duplicate_ui_folder_destination(
    destination_string, destination_iterator
):
  """Resolves duplicate resource results for placeholder folders created through the UI.

  In the scenario where a user creates a placeholder folder
  (which is actually an object ending with a '/' rather than a true folder as
  in the case of HNS buckets), the CLI, when resolving for destination gets
  two results as part of the ListObjects API call. One of these is of type
  GCSObjectResource, while the other is PrefixResource. Technically both
  results are correct and expected. But in our logic, we end up interpreting
  this case as multiple destinations which we do not support.

  This method determines if the given results come under the above scenario.

  Args:
    destination_string (str): A string representing the destination url.
    destination_iterator (PluralityCheckableIterator): Contains results from
      the destination search through the wildcard iterator.

  Returns:
    PrefixResource out of the two results of duplicate resources due to UI
    folder creation, None otherwise.
  """
  # The first condition would be to make sure that the destination string
  # is of the type CloudURL and a GCS schema, because this case does not apply
  # to any other type of destination.
  destination_storage_url = storage_url.storage_url_from_string(
      destination_string
  )
  if (
      not isinstance(destination_storage_url, storage_url.CloudUrl)
      or destination_storage_url.scheme != storage_url.ProviderPrefix.GCS
  ):
    return None
  # NOTE: consumes the iterator; callers must rebuild it from the returned
  # resource on success.
  destination_resource_1 = next(destination_iterator)
  destination_resource_2 = next(destination_iterator)
  # In case of a Folder created through the UI, we expect two resources.
  # We never expect more than that to exist. So if we do encounter that case,
  # then this is not the scenario of a UI created folder.
  if not destination_iterator.is_empty():
    return None
  # Types of both resources cannot be the same since we expect a mix of
  # GCSResourceReference and PrefixResource to be returned
  # from the WildcardIterator in the case of Folders which are a part of the
  # UI.
  if isinstance(destination_resource_1, type(destination_resource_2)):
    return None
  # At least one of the resource has to be of type GcsObjectResource.
  if not (
      isinstance(
          destination_resource_1, gcs_resource_reference.GcsObjectResource
      )
      or isinstance(
          destination_resource_2, gcs_resource_reference.GcsObjectResource
      )
  ):
    return None
  # Once we have determined that at least one of the resource is of type
  # GcsObjectResource, we need to ensure that one of them is PrefixResource.
  # In the case where we have two GcsObjectResource or one of them is not of
  # type PrefixResource, we will return False as this is not a UI created
  # folder case for sure.
  if not (
      isinstance(destination_resource_1, resource_reference.PrefixResource)
      or isinstance(destination_resource_2, resource_reference.PrefixResource)
  ):
    return None
  # Both URLs must end with '/' and be identical — the signature of a
  # placeholder-folder object plus its prefix listing.
  if (
      destination_resource_1.storage_url.versionless_url_string.endswith('/')
      and destination_resource_2.storage_url.versionless_url_string.endswith(
          '/'
      )
  ) and (
      destination_resource_1.storage_url.versionless_url_string
      == destination_resource_2.storage_url.versionless_url_string
  ):
    return (
        destination_resource_1
        if isinstance(destination_resource_1, resource_reference.PrefixResource)
        else destination_resource_2
    )
  return None
def _get_raw_destination(destination_string, folders_only=False):
  """Converts destination_string to a destination resource.

  Args:
    destination_string (str): A string representing the destination url.
    folders_only (bool): If True, indicates that we are invoking folders only
      copy task.

  Returns:
    A resource_reference.Resource. Note that this resource may not be a valid
    copy destination if it is a BucketResource, PrefixResource,
    FileDirectoryResource or UnknownResource.

  Raises:
    InvalidUrlError if the destination url is a cloud provider or if it
      specifies a version.
  """
  destination_url = storage_url.storage_url_from_string(destination_string)
  if isinstance(destination_url, storage_url.CloudUrl):
    if destination_url.is_provider():
      raise errors.InvalidUrlError(
          'The cp command does not support provider-only destination URLs.'
      )
    elif destination_url.generation is not None:
      raise errors.InvalidUrlError(
          'The destination argument of the cp command cannot be a '
          'version-specific URL ({}).'.format(destination_string)
      )
  raw_destination = _expand_destination_wildcards(
      destination_string, folders_only
  )
  if raw_destination:
    return raw_destination
  # No existing resource matched: wrap the URL so downstream code can treat
  # it as a possibly-not-yet-existing destination.
  return resource_reference.UnknownResource(destination_url)
def _destination_is_container(destination):
  """Returns True if the destination can be treated as a container.

  For a CloudUrl, a container is a bucket or a prefix; when the destination
  does not exist, this is inferred from a trailing delimiter. For a FileUrl,
  a container is an existing directory; non-existing paths are not
  containers.

  Args:
    destination (resource_reference.Resource): The destination container.

  Returns:
    bool: True if destination is a valid container.
  """
  try:
    if destination.is_container():
      return True
  except errors.ValueCannotBeDeterminedError:
    # Some resource classes are not clearly containers (e.g. objects whose
    # names end in a delimiter). Treat them as containers anyway so nesting
    # at copy destinations works as expected.
    pass

  url = destination.storage_url
  if isinstance(url, storage_url.FileUrl):
    # Non-existing file paths are not valid containers.
    return os.path.isdir(url.resource_name)
  if isinstance(url, storage_url.CloudUrl) and url.is_bucket():
    return True
  return url.versionless_url_string.endswith(url.delimiter)
def _resource_is_stream(resource):
  """Returns True if the resource's URL refers to a local pipe-type stream."""
  url = resource.storage_url
  return isinstance(url, storage_url.FileUrl) and url.is_stream
def _is_expanded_url_valid_parent_dir(expanded_url):
  """Returns True if not FileUrl ending in relative path symbols.

  A URL is invalid if it is a FileUrl and the parent directory of the file is
  a relative path symbol. Unix will not allow a file itself to be named with
  a relative path symbol, but one can be the parent. Notably, "../obj" can
  lead to unexpected behavior at the copy destination, so we examine the
  pre-recursion expanded_url, which might point to "..", to see if the parent
  is valid.

  If the user does a recursive copy from an expanded URL, it may not end up
  the final parent of the copied object (e.g. "dir/nested_dir/obj" copied via
  "cp -r d* gs://bucket" has expanded_url "dir" but parent "nested_dir").
  That does not matter here, since recursion never adds relative path symbols
  to the path; we still validate expanded_url because some cases require
  copying every parent directory up to it to avoid file name conflicts.

  Args:
    expanded_url (StorageUrl): NameExpansionResult.expanded_url value — the
      wildcard-expanded URL before recursion.

  Returns:
    bool: Whether expanded_url is valid as a parent directory.
  """
  if not isinstance(expanded_url, storage_url.FileUrl):
    return True

  trimmed_url_string = expanded_url.versionless_url_string.rstrip(
      expanded_url.delimiter
  )
  final_segment = trimmed_url_string.rpartition(expanded_url.delimiter)[2]

  # Also reject scheme-prefixed forms like "file://..", which appear when the
  # whole URL is nothing but the scheme plus a relative symbol.
  scheme_prefixed_symbols = {
      expanded_url.scheme.value + '://' + symbol
      for symbol in _RELATIVE_PATH_SYMBOLS
  }
  return (
      final_segment not in _RELATIVE_PATH_SYMBOLS
      and final_segment not in scheme_prefixed_symbols
  )
class CopyTaskIterator:
  """Iterates over each expanded source and creates an appropriate copy task."""
  def __init__(
      self,
      source_name_iterator,
      destination_string,
      custom_md5_digest=None,
      delete_source=False,
      do_not_decompress=False,
      force_daisy_chain=False,
      print_created_message=False,
      shared_stream=None,
      skip_unsupported=True,
      task_status_queue=None,
      user_request_args=None,
      folders_only=False,
  ):
    """Initializes a CopyTaskIterator instance.
    Args:
      source_name_iterator (name_expansion.NameExpansionIterator): yields
        resource_reference.Resource objects with expanded source URLs.
      destination_string (str): The copy destination path or url.
      custom_md5_digest (str|None): User-added MD5 hash output to send to server
        for validating a single resource upload.
      delete_source (bool): If copy completes successfully, delete the source
        object afterwards.
      do_not_decompress (bool): Prevents automatically decompressing downloaded
        gzips.
      force_daisy_chain (bool): If True, yields daisy chain copy tasks in place
        of intra-cloud copy tasks.
      print_created_message (bool): Print the versioned URL of each successfully
        copied object.
      shared_stream (stream): Multiple tasks may reuse a read or write stream.
      skip_unsupported (bool): Skip creating copy tasks for unsupported object
        types.
      task_status_queue (multiprocessing.Queue|None): Used for estimating total
        workload from this iterator.
      user_request_args (UserRequestArgs|None): Values for RequestConfig.
      folders_only (bool): If True, perform only folders tasks.
    """
    self._all_versions = (
        source_name_iterator.object_state
        is cloud_api.ObjectState.LIVE_AND_NONCURRENT
    )
    self._has_multiple_top_level_sources = (
        source_name_iterator.has_multiple_top_level_resources)
    # Track source locality to decide later whether to suggest the Storage
    # Transfer Service for large cloud-to-GCS copies.
    self._has_cloud_source = False
    self._has_local_source = False
    # Wrapping the iterator allows peeking/plurality checks without consuming.
    self._source_name_iterator = (
        plurality_checkable_iterator.PluralityCheckableIterator(
            source_name_iterator))
    self._multiple_sources = self._source_name_iterator.is_plural()
    self._custom_md5_digest = custom_md5_digest
    self._delete_source = delete_source
    self._do_not_decompress = do_not_decompress
    self._force_daisy_chain = force_daisy_chain
    self._print_created_message = print_created_message
    self._shared_stream = shared_stream
    self._skip_unsupported = skip_unsupported
    self._task_status_queue = task_status_queue
    self._user_request_args = user_request_args
    self._folders_only = folders_only
    # Workload estimation accumulators; -1 later marks them as untrackable.
    self._total_file_count = 0
    self._total_size = 0
    self._raw_destination = _get_raw_destination(
        destination_string, self._folders_only
    )
    if self._multiple_sources:
      self._raise_if_destination_is_file_url_and_not_a_directory_or_pipe()
    else:
      # For multiple sources,
      # _raise_if_destination_is_file_url_and_not_a_directory_or_pipe already
      # checks for directory's existence.
      self._raise_if_download_destination_ends_with_delimiter_and_does_not_exist()
    if self._multiple_sources and self._custom_md5_digest:
      raise errors.Error(
          'Received multiple objects to upload, but only one'
          ' custom MD5 digest is allowed.'
      )
    # Sources already recorded in a prior manifest run are skipped in __iter__.
    self._already_completed_sources = manifest_util.parse_for_completed_sources(
        getattr(user_request_args, 'manifest_path', None))
  def _raise_error_if_source_matches_destination(self):
    """Raises InvalidUrlError for a single source equal to the destination."""
    if not self._multiple_sources and not self._source_name_iterator.is_empty():
      source_url = self._source_name_iterator.peek().expanded_url
      if source_url == self._raw_destination.storage_url:
        raise errors.InvalidUrlError(
            'Source URL matches destination URL: {}'.format(source_url))
  def _raise_error_if_expanded_source_matches_expanded_destination(
      self, expanded_source_url, expanded_destination_url
  ):
    """Raises InvalidUrlError if a copy would write a source onto itself."""
    if expanded_source_url == expanded_destination_url:
      raise errors.InvalidUrlError(
          'Destination URL {} already exists.'.format(expanded_destination_url)
      )
  def _raise_if_destination_is_file_url_and_not_a_directory_or_pipe(self):
    # A local destination for multiple items must be a directory (or a
    # pipe/stream, which can absorb multiple writes).
    if (isinstance(self._raw_destination.storage_url, storage_url.FileUrl) and
        not (_destination_is_container(self._raw_destination) or
             self._raw_destination.storage_url.is_stream)):
      raise errors.InvalidUrlError(
          'Destination URL must name an existing directory.'
          ' Provided: {}.'.format(
              self._raw_destination.storage_url.resource_name))
  def _raise_if_download_destination_ends_with_delimiter_and_does_not_exist(
      self,
  ):
    if isinstance(self._raw_destination.storage_url, storage_url.FileUrl):
      # Download operation.
      destination_path = self._raw_destination.storage_url.resource_name
      # A trailing delimiter expresses directory intent, so the directory
      # must already exist.
      if destination_path.endswith(
          self._raw_destination.storage_url.delimiter
      ) and not self._raw_destination.storage_url.isdir():
        raise errors.InvalidUrlError(
            'Destination URL must name an existing directory if it ends with a'
            ' delimiter. Provided: {}.'.format(destination_path)
        )
  def _update_workload_estimation(self, resource):
    """Updates total_file_count and total_size.
    Args:
      resource (resource_reference.Resource): Any type of resource. Parse to
        help estimate total workload.
    """
    if self._total_file_count == -1 or self._total_size == -1:
      # -1 is signal that data is corrupt and not worth tracking.
      return
    try:
      if resource.is_container():
        return
      size = resource.size
      if isinstance(resource, resource_reference.FileObjectResource):
        self._has_local_source = True
      elif isinstance(resource, resource_reference.ObjectResource):
        self._has_cloud_source = True
      else:
        raise errors.ValueCannotBeDeterminedError
    except (OSError, errors.ValueCannotBeDeterminedError):
      # Streams legitimately have no size, so only log for other resources.
      if not _resource_is_stream(resource):
        log.error('Could not get size of resource {}.'.format(resource))
      # Poison the accumulators so estimation stops for the whole iteration.
      self._total_file_count = -1
      self._total_size = -1
    else:
      self._total_file_count += 1
      self._total_size += size or 0
  def _print_skip_and_maybe_send_to_manifest(self, message, source):
    """Prints why task is being skipped and maybe records in manifest."""
    log.status.Print(message)
    if (
        self._user_request_args
        and self._user_request_args.manifest_path
        and self._task_status_queue
    ):
      manifest_util.send_skip_message(
          self._task_status_queue,
          source.resource,
          self._raw_destination,
          message,
      )
  def __iter__(self):
    """Yields a copy task per source; may report workload estimates after."""
    self._raise_error_if_source_matches_destination()
    is_source_plural = self._source_name_iterator.is_plural()
    for source in self._source_name_iterator:
      if self._folders_only and not isinstance(
          source.resource, resource_reference.FolderResource
      ):
        continue
      if self._delete_source:
        # mv operation: warn/raise before deleting cold-storage-class sources.
        copy_util.raise_if_mv_early_deletion_fee_applies(source.resource)
      if self._skip_unsupported:
        unsupported_type = resource_util.get_unsupported_object_type(
            source.resource)
        if unsupported_type:
          message = resource_util.UNSUPPORTED_OBJECT_WARNING_FORMAT.format(
              source.resource.storage_url, unsupported_type.value
          )
          self._print_skip_and_maybe_send_to_manifest(message, source)
          continue
      if (
          source.resource.storage_url.url_string
          in self._already_completed_sources
      ):
        message = (
            'Skipping item {} because manifest marks it as'
            ' skipped or completed.'
        ).format(source.resource.storage_url)
        self._print_skip_and_maybe_send_to_manifest(message, source)
        continue
      destination_resource = self._get_copy_destination(
          self._raw_destination, source, is_source_plural
      )
      source_url = source.resource.storage_url
      destination_url = destination_resource.storage_url
      self._raise_error_if_expanded_source_matches_expanded_destination(
          source_url, destination_url
      )
      # Folders-only moves are limited to renames within the same bucket.
      if (
          self._folders_only
          and self._delete_source
          and (
              source_url.scheme != destination_url.scheme
              or source_url.bucket_name != destination_url.bucket_name
          )
      ):
        continue
      posix_util.run_if_setting_posix(
          posix_to_set=None,
          user_request_args=self._user_request_args,
          function=posix_util.raise_if_source_and_destination_not_valid_for_preserve_posix,
          source_url=source_url,
          destination_url=destination_url,
      )
      if (isinstance(source.resource, resource_reference.ObjectResource) and
          isinstance(destination_url, storage_url.FileUrl) and
          destination_url.resource_name.endswith(destination_url.delimiter)):
        log.debug('Skipping downloading {} to {} since the destination ends in'
                  ' a file system delimiter.'.format(
                      source_url.versionless_url_string,
                      destination_url.versionless_url_string))
        continue
      if (not self._multiple_sources and source_url.versionless_url_string !=
          source.expanded_url.versionless_url_string):
        # Multiple sources have been already validated in __init__.
        # This check is required for cases where recursion has been requested,
        # but there is only one object that needs to be copied over.
        self._raise_if_destination_is_file_url_and_not_a_directory_or_pipe()
      if self._custom_md5_digest:
        source.resource.md5_hash = self._custom_md5_digest
      self._update_workload_estimation(source.resource)
      yield copy_task_factory.get_copy_task(
          source.resource,
          destination_resource,
          do_not_decompress=self._do_not_decompress,
          delete_source=self._delete_source,
          force_daisy_chain=self._force_daisy_chain,
          print_created_message=self._print_created_message,
          print_source_version=(
              source.original_url.generation or self._all_versions
          ),
          shared_stream=self._shared_stream,
          verbose=True,
          user_request_args=self._user_request_args,
      )
    if self._task_status_queue and (
        self._total_file_count > 0 or self._total_size > 0
    ):
      # Show fraction of total copies completed now that we know totals.
      progress_callbacks.workload_estimator_callback(
          self._task_status_queue,
          item_count=self._total_file_count,
          size=self._total_size,
      )
    if (
        self._total_size > _ONE_TB_IN_BYTES
        and self._has_cloud_source
        and not self._has_local_source
        and self._raw_destination.storage_url.scheme
        is storage_url.ProviderPrefix.GCS
        and properties.VALUES.storage.suggest_transfer.GetBool()
    ):
      log.status.Print(
          'For large copies, consider the `gcloud transfer jobs create ...`'
          ' command. Learn more at'
          '\nhttps://cloud.google.com/storage-transfer-service'
          '\nRun `gcloud config set storage/suggest_transfer False` to'
          ' disable this message.'
      )
  def _get_copy_destination(
      self, raw_destination, source, is_source_plural=False
  ):
    """Returns the final destination StorageUrl instance."""
    completion_is_necessary = (
        _destination_is_container(raw_destination)
        or (self._multiple_sources and not _resource_is_stream(raw_destination))
        or source.resource.storage_url.versionless_url_string
        != source.expanded_url.versionless_url_string  # Recursion case.
    )
    if completion_is_necessary:
      # stdin has no name to append to a container destination.
      if (
          isinstance(source.expanded_url, storage_url.FileUrl)
          and source.expanded_url.is_stdio
      ):
        raise errors.Error(
            'Destination object name needed when source is stdin.'
        )
      destination_resource = self._complete_destination(
          raw_destination, source, is_source_plural
      )
    else:
      destination_resource = raw_destination
    sanitized_destination_resource = (
        path_util.sanitize_file_resource_for_windows(destination_resource)
    )
    return sanitized_destination_resource
  def _complete_destination(
      self, destination_container, source, is_source_plural=False
  ):
    """Gets a valid copy destination incorporating part of the source's name.
    When given a source file or object and a destination resource that should
    be treated as a container, this function uses the last part of the source's
    name to get an object or file resource representing the copy destination.
    For example: given a source `dir/file` and a destination `gs://bucket/`, the
    destination returned is a resource representing `gs://bucket/file`. Check
    the recursive helper function docstring for details on recursion handling.
    Args:
      destination_container (resource_reference.Resource): The destination
        container.
      source (NameExpansionResult): Represents the source resource and the
        expanded parent url in case of recursion.
      is_source_plural (bool): True if the source is a plural resource.
    Returns:
      The completed destination, a resource_reference.Resource.
    """
    destination_url = destination_container.storage_url
    source_url = source.resource.storage_url
    if (
        source_url.versionless_url_string
        != source.expanded_url.versionless_url_string
    ):
      # In case of recursion, the expanded_url can be the expanded wildcard URL
      # representing the container, and the source url can be the file/object.
      destination_suffix = self._get_destination_suffix_for_recursion(
          destination_container, source
      )
    else:
      # On Windows with a relative path URL like file://file.txt, partitioning
      # on the delimiter will fail to remove file://, so destination_suffix
      # would include the scheme. We remove the scheme here to avoid this.
      _, _, url_without_scheme = source_url.versionless_url_string.rpartition(
          source_url.scheme.value + '://'
      )
      # Ignores final slashes when completing names. For example, where
      # source_url is gs://bucket/folder/ and destination_url is gs://bucket1,
      # the completed URL should be gs://bucket1/folder/.
      if url_without_scheme.endswith(source_url.delimiter):
        url_without_scheme_and_trailing_delimiter = (
            url_without_scheme[:-len(source_url.delimiter)]
        )
      else:
        url_without_scheme_and_trailing_delimiter = url_without_scheme
      _, _, destination_suffix = (
          url_without_scheme_and_trailing_delimiter.rpartition(
              source_url.delimiter
          )
      )
      if url_without_scheme_and_trailing_delimiter != url_without_scheme:
        # Adds the removed delimiter back.
        destination_suffix += source_url.delimiter
    destination_url_prefix = storage_url.storage_url_from_string(
        destination_url.versionless_url_string.rstrip(destination_url.delimiter)
    )
    # For folders use-case, we want to rename/copy to the folder as the name
    # of the destination if it does not exist. This is similar to the Filesystem
    # and does not happen for flat buckets today. Hence this additional logic.
    if (
        self._folders_only
        and isinstance(source.resource, resource_reference.FolderResource)
        and not isinstance(
            destination_container, resource_reference.FolderResource
        )
        and not is_source_plural
    ):
      return resource_reference.UnknownResource(destination_url_prefix)
    new_destination_url = destination_url_prefix.join(destination_suffix)
    return resource_reference.UnknownResource(new_destination_url)
  def _get_destination_suffix_for_recursion(
      self, destination_container, source
  ):
    """Returns the suffix required to complete the destination URL.
    Let's assume the following:
      User command => cp -r */base_dir gs://dest/existing_prefix
      source.resource.storage_url => a/base_dir/c/d.txt
      source.expanded_url => a/base_dir
      destination_container.storage_url => gs://dest/existing_prefix
    If the destination container exists, the entire directory gets copied:
    Result => gs://dest/existing_prefix/base_dir/c/d.txt
    Args:
      destination_container (resource_reference.Resource): The destination
        container.
      source (NameExpansionResult): Represents the source resource and the
        expanded parent url in case of recursion.
    Returns:
      (str) The suffix to be appended to the destination container.
    """
    source_prefix_to_ignore = storage_url.rstrip_one_delimiter(
        source.expanded_url.versionless_url_string,
        source.expanded_url.delimiter,
    )
    expanded_url_is_valid_parent = _is_expanded_url_valid_parent_dir(
        source.expanded_url
    )
    if (
        not expanded_url_is_valid_parent
        and self._has_multiple_top_level_sources
    ):
      # To avoid top-level name conflicts, we need to copy the parent dir.
      # However, that cannot be done because the parent dir has an invalid name.
      raise errors.InvalidUrlError(
          'Presence of multiple top-level sources and invalid expanded URL'
          ' make file name conflicts possible for URL: {}'.format(
              source.resource
          )
      )
    is_top_level_source_object_name_conflict_possible = (
        isinstance(destination_container, resource_reference.UnknownResource)
        and self._has_multiple_top_level_sources
    )
    destination_exists = not isinstance(
        destination_container, resource_reference.UnknownResource
    )
    destination_is_existing_dir = (
        destination_exists and destination_container.is_container()
    )
    # A nonexistent destination ending in a delimiter expresses the user's
    # intent that it be treated as a directory.
    treat_destination_as_existing_dir = destination_is_existing_dir or (
        not destination_exists
        and destination_container.storage_url.url_string.endswith(
            destination_container.storage_url.delimiter
        )
    )
    if is_top_level_source_object_name_conflict_possible or (
        expanded_url_is_valid_parent and treat_destination_as_existing_dir
    ):
      # Remove the leaf name unless it is a relative path symbol, so that
      # only top-level source directories are ignored.
      # Presence of relative path symbols needs to be checked with the source
      # to distinguish file://dir.. from file://dir/..
      source_delimiter = source.resource.storage_url.delimiter
      relative_path_characters_end_source_prefix = [
          source_prefix_to_ignore.endswith(source_delimiter + i)
          for i in _RELATIVE_PATH_SYMBOLS
      ]
      # On Windows, source paths that are relative path symbols will not contain
      # the source delimiter, e.g. file://.. This case thus needs to be detected
      # separately.
      source_url_scheme_string = source.expanded_url.scheme.value + '://'
      source_prefix_to_ignore_without_scheme = source_prefix_to_ignore[
          len(source_url_scheme_string):]
      source_is_relative_path_symbol = (
          source_prefix_to_ignore_without_scheme in _RELATIVE_PATH_SYMBOLS)
      if (not any(relative_path_characters_end_source_prefix) and
          not source_is_relative_path_symbol):
        source_prefix_to_ignore, _, _ = source_prefix_to_ignore.rpartition(
            source.expanded_url.delimiter)
      if not source_prefix_to_ignore:
        # In case of Windows, the source URL might not contain any Windows
        # delimiter if it was a single directory (e.g file://dir) and
        # source_prefix_to_ignore will be empty. Set it to <scheme>://.
        # TODO(b/169093672) This will not be required if we get rid of file://
        source_prefix_to_ignore = source.expanded_url.scheme.value + '://'
    full_source_url = source.resource.storage_url.versionless_url_string
    delimiter = source.resource.storage_url.delimiter
    suffix_for_destination = delimiter + (
        full_source_url.split(source_prefix_to_ignore)[1]
    ).lstrip(delimiter)
    # Windows uses \ as a delimiter. Force the suffix to use the same
    # delimiter used by the destination container.
    source_delimiter = source.resource.storage_url.delimiter
    destination_delimiter = destination_container.storage_url.delimiter
    if source_delimiter != destination_delimiter:
      return suffix_for_destination.replace(
          source_delimiter, destination_delimiter
      )
    return suffix_for_destination

View File

@@ -0,0 +1,222 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities for copies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import datetime
from googlecloudsdk.api_lib.storage import cloud_api
from googlecloudsdk.api_lib.storage import errors as api_errors
from googlecloudsdk.command_lib.storage import errors as command_errors
from googlecloudsdk.command_lib.storage import manifest_util
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage.resources import resource_reference
from googlecloudsdk.command_lib.storage.resources import resource_util
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
_EARLY_DELETION_MINIMUM_DAYS = {
'nearline': 30,
'coldline': 90,
'archive': 365,
}
class CopyTask(task.Task):
  """Parent task that handles common attributes and an __init__ status print."""
  def __init__(
      self,
      source_resource,
      destination_resource,
      print_created_message=False,
      print_source_version=False,
      user_request_args=None,
      verbose=False,
  ):
    """Initializes task.
    Args:
      source_resource (resource_reference.Resource): Source resource to copy.
      destination_resource (resource_reference.Resource): Target resource to
        copy to.
      print_created_message (bool): Print a message containing the URL of the
        copy result.
      print_source_version (bool): Print source object version in status message
        enabled by the `verbose` kwarg.
      user_request_args (UserRequestArgs|None): Various user-set values
        typically converted to an API-specific RequestConfig.
      verbose (bool): Print a "copying" status message on initialization.
    """
    super(CopyTask, self).__init__()
    self._source_resource = source_resource
    self._destination_resource = destination_resource
    self._print_created_message = print_created_message
    self._print_source_version = print_source_version
    self._user_request_args = user_request_args
    self._verbose = verbose
    # Manifest messages are only relevant when a manifest path was supplied.
    self._send_manifest_messages = bool(
        self._user_request_args and self._user_request_args.manifest_path
    )
    if verbose:
      source_url = source_resource.storage_url
      # Include the generation in the printed source only when requested.
      source_string = (
          source_url.url_string
          if self._print_source_version
          else source_url.versionless_url_string
      )
      log.status.Print(
          'Copying {} to {}'.format(
              source_string,
              destination_resource.storage_url.versionless_url_string,
          )
      )
  def _print_created_message_if_requested(self, resource):
    """Prints the created-resource message when enabled at init time."""
    if not self._print_created_message:
      return
    log.status.Print('Created: {}'.format(resource))
class ObjectCopyTask(CopyTask):
  """Parent task that handles common attributes for object copy tasks."""
  def __init__(
      self,
      source_resource,
      destination_resource,
      posix_to_set=None,
      print_created_message=False,
      print_source_version=False,
      user_request_args=None,
      verbose=False,
  ):
    """Initializes task.
    Args:
      source_resource (resource_reference.Resource): See parent class.
      destination_resource (resource_reference.Resource): See parent class.
      posix_to_set (PosixAttributes|None): POSIX info set as custom cloud
        metadata on target.
      print_created_message (bool): See parent class.
      print_source_version (bool): See parent class.
      user_request_args (UserRequestArgs|None): See parent class.
      verbose (bool): Print a "copying" status message on initialization.
    """
    self._posix_to_set = posix_to_set
    # Assigned before the super().__init__ call because parent initialization
    # reads this attribute when building the status message.
    self._print_source_version = print_source_version
    super(ObjectCopyTask, self).__init__(
        source_resource,
        destination_resource,
        print_created_message=print_created_message,
        print_source_version=print_source_version,
        user_request_args=user_request_args,
        verbose=verbose,
    )
class _ExitHandlerMixin:
"""Provides an exit handler for copy tasks."""
def exit_handler(self, error=None, task_status_queue=None):
"""Send copy result info to manifest if requested."""
if error and self._send_manifest_messages:
if not task_status_queue:
raise command_errors.Error(
'Unable to send message to manifest for source: {}'.format(
self._source_resource
)
)
manifest_util.send_error_message(task_status_queue, self._source_resource,
self._destination_resource, error)
class CopyTaskWithExitHandler(
    # _ExitHandlerMixin must precede CopyTask, otherwise task.Task.exit_handler
    # overrides the intended implementation via the MRO.
    _ExitHandlerMixin,
    CopyTask,
):
  """Parent task with an exit handler for non-object copy tasks."""
# _ExitHandlerMixin precedes ObjectCopyTask so its exit_handler wins in the MRO.
class ObjectCopyTaskWithExitHandler(_ExitHandlerMixin, ObjectCopyTask):
  """Parent task with an exit handler for object copy tasks."""
def get_no_clobber_message(destination_url):
  """Builds the standardized warning for skipping an existing destination."""
  template = 'Skipping existing destination item (no-clobber): {}'
  return template.format(destination_url)
def check_for_cloud_clobber(user_request_args, api_client,
                            destination_resource):
  """Returns if cloud destination object exists if no-clobber enabled."""
  no_clobber_requested = bool(
      user_request_args and user_request_args.no_clobber)
  if not no_clobber_requested:
    return False
  destination_url = destination_resource.storage_url
  try:
    api_client.get_object_metadata(
        destination_url.bucket_name,
        destination_url.resource_name,
        fields_scope=cloud_api.FieldsScope.SHORT)
  except api_errors.NotFoundError:
    # Nothing exists at the destination, so there is nothing to clobber.
    return False
  return True
def get_generation_match_value(request_config):
  """Prioritizes user-input generation over no-clobber zero value."""
  user_generation = request_config.precondition_generation_match
  if user_generation is not None:
    return user_generation
  # Generation 0 asserts the object does not yet exist, implementing
  # no-clobber on the server side.
  return 0 if request_config.no_clobber else None
def raise_if_mv_early_deletion_fee_applies(object_resource):
  """Raises error if Google Cloud Storage object will incur an extra charge.

  Early deletion fees apply to GCS objects in cold storage classes that are
  deleted before the class's minimum storage duration elapses. Because the
  source of a mv operation is deleted, surface the charge before proceeding.

  Args:
    object_resource (resource_reference.Resource): Source resource of the mv
      operation. Folder resources are exempt (no storage class).

  Raises:
    exceptions.Error: If the storage/check_mv_early_deletion_fee property is
      enabled and deleting the object may incur an early deletion charge.
  """
  if isinstance(object_resource, resource_reference.FolderResource):
    return
  # The API may report storage classes in any case (e.g. "NEARLINE"), so
  # normalize once before comparing against the lowercase table keys and use
  # the same normalized value for the lookup below. Guard against a missing
  # storage class with the empty-string fallback.
  storage_class = (object_resource.storage_class or '').lower()
  if not (properties.VALUES.storage.check_mv_early_deletion_fee.GetBool() and
          object_resource.storage_url.scheme is storage_url.ProviderPrefix.GCS
          and object_resource.creation_time and
          storage_class in _EARLY_DELETION_MINIMUM_DAYS):
    return
  minimum_lifetime = _EARLY_DELETION_MINIMUM_DAYS[storage_class]
  # Compare in UTC so creation time and "now" share a timezone.
  creation_datetime_utc = resource_util.convert_datetime_object_to_utc(
      object_resource.creation_time)
  current_datetime_utc = resource_util.convert_datetime_object_to_utc(
      datetime.datetime.now())
  if current_datetime_utc < creation_datetime_utc + datetime.timedelta(
      days=minimum_lifetime):
    raise exceptions.Error(
        ('Deleting {} may incur an early deletion charge. Note: the source'
         ' object of a mv operation is deleted.\nThe object appears to have'
         ' been created on {}, and the minimum time before deletion for the {}'
         ' storage class is {} days.\nTo allow deleting the object anyways, run'
         ' "gcloud config set storage/check_mv_early_deletion_fee False"'
        ).format(object_resource, object_resource.creation_time,
                 object_resource.storage_class, minimum_lifetime))

View File

@@ -0,0 +1,691 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for daisy-chain copies.
Typically executed in a task iterator:
googlecloudsdk.command_lib.storage.tasks.task_executor.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import collections
import copy
import io
import os
import threading
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import cloud_api
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import manifest_util
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.command_lib.storage.tasks import task_status
from googlecloudsdk.command_lib.storage.tasks.cp import copy_util
from googlecloudsdk.command_lib.storage.tasks.cp import upload_util
from googlecloudsdk.command_lib.storage.tasks.rm import delete_task
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
_MAX_ALLOWED_READ_SIZE = 100 * 1024 * 1024 # 100 MiB
_MAX_BUFFER_QUEUE_SIZE = 100
# TODO(b/174075495) Determine the max size based on the destination scheme.
_QUEUE_ITEM_MAX_SIZE = 8 * 1024 # 8 KiB
_PROGRESS_CALLBACK_THRESHOLD = 16 * 1024 * 1024 # 16 MiB.
# Distinct from generic errors so stream readers/writers can unwind cleanly
# when another daisy-chain thread fails.
class _AbruptShutdownError(errors.Error):
  """Raised if a thread is terminated because of an error in another thread."""
class _WritableStream:
  """A write-only stream class that writes to the buffer queue."""
  def __init__(self, buffer_queue, buffer_condition, shutdown_event):
    """Initializes WritableStream.
    Args:
      buffer_queue (collections.deque): A queue where the data gets written.
      buffer_condition (threading.Condition): The condition object to wait on if
        the buffer is full.
      shutdown_event (threading.Event): Used for signaling the thread to
        terminate.
    """
    self._buffer_queue = buffer_queue
    self._buffer_condition = buffer_condition
    self._shutdown_event = shutdown_event
  def write(self, data):
    """Splits data into chunks and appends each to the buffer queue.
    Chunks are capped at QUEUE_ITEM_MAX_SIZE bytes because reads typically
    request that size; matching chunk and read sizes improves performance.
    Blocks while the queue already holds MAX_BUFFER_QUEUE_SIZE items so the
    full payload is never buffered in memory at once.
    Args:
      data (bytes): The bytes that should be added to the queue.
    Raises:
      _AbruptShutdownError: If self._shutdown_event was set.
    """
    data_length = len(data)
    chunk_start = 0
    while chunk_start < data_length:
      chunk_end = min(chunk_start + _QUEUE_ITEM_MAX_SIZE, data_length)
      with self._buffer_condition:
        while (not self._shutdown_event.is_set() and
               len(self._buffer_queue) >= _MAX_BUFFER_QUEUE_SIZE):
          self._buffer_condition.wait()
        if self._shutdown_event.is_set():
          raise _AbruptShutdownError()
        self._buffer_queue.append(data[chunk_start:chunk_end])
        self._buffer_condition.notify_all()
      chunk_start = chunk_end
class _ReadableStream:
"""A read-only stream that reads from the buffer queue."""
def __init__(self, buffer_queue, buffer_condition, shutdown_event,
end_position, restart_download_callback,
progress_callback=None,
seekable=True):
"""Initializes ReadableStream.
Args:
buffer_queue (collections.deque): The underlying queue from which the data
gets read.
buffer_condition (threading.Condition): The condition object to wait on if
the buffer is empty.
shutdown_event (threading.Event): Used for signaling the thread to
terminate.
end_position (int): Position at which the stream reading stops. This is
usually the total size of the data that gets read.
restart_download_callback (func): This must be the
BufferController.restart_download function.
progress_callback (progress_callbacks.FilesAndBytesProgressCallback):
Accepts processed bytes and submits progress info for aggregation.
seekable (bool): Value for the "seekable" method call.
"""
self._buffer_queue = buffer_queue
self._buffer_condition = buffer_condition
self._end_position = end_position
self._shutdown_event = shutdown_event
self._position = 0
self._unused_data_from_previous_read = b''
self._progress_callback = progress_callback
self._restart_download_callback = restart_download_callback
self._bytes_read_since_last_progress_callback = 0
self._seekable = seekable
self._is_closed = False
def _restart_download(self, offset):
self._restart_download_callback(offset)
self._unused_data_from_previous_read = b''
self._bytes_read_since_last_progress_callback = 0
self._position = offset
def read(self, size=-1):
"""Reads size bytes from the buffer queue and returns it.
This method will be blocked if the buffer_queue is empty.
If size > length of data available, the entire data is sent over.
Args:
size (int): The number of bytes to be read.
Returns:
Bytes of length 'size'. May return bytes of length less than the size
if there are no more bytes left to be read.
Raises:
_AbruptShutdownError: If self._shudown_event was set.
storage.errors.Error: If size is not within the allowed range of
[-1, MAX_ALLOWED_READ_SIZE] OR
If size is -1 but the object size is greater than MAX_ALLOWED_READ_SIZE.
"""
if size == 0:
return b''
if size > _MAX_ALLOWED_READ_SIZE:
raise errors.Error(
'Invalid HTTP read size {} during daisy chain operation, expected'
' -1 <= size <= {} bytes.'.format(size, _MAX_ALLOWED_READ_SIZE))
if size == -1:
# This indicates that we have to read the entire object at once.
if self._end_position <= _MAX_ALLOWED_READ_SIZE:
chunk_size = self._end_position
else:
raise errors.Error('Read with size=-1 is not allowed for object'
' size > {} bytes to prevent reading large objects'
' in-memory.'.format(_MAX_ALLOWED_READ_SIZE))
else:
chunk_size = size
result = io.BytesIO()
bytes_read = 0
while bytes_read < chunk_size and self._position < self._end_position:
if not self._unused_data_from_previous_read:
with self._buffer_condition:
while not self._buffer_queue and not self._shutdown_event.is_set():
self._buffer_condition.wait()
# The shutdown_event needs to be checked before the data is fetched
# from the buffer.
if self._shutdown_event.is_set():
raise _AbruptShutdownError()
data = self._buffer_queue.popleft()
self._buffer_condition.notify_all()
else:
# Data is already present from previous read.
if self._shutdown_event.is_set():
raise _AbruptShutdownError()
data = self._unused_data_from_previous_read
if bytes_read + len(data) > chunk_size:
self._unused_data_from_previous_read = data[chunk_size - bytes_read:]
data_to_return = data[:chunk_size - bytes_read]
else:
self._unused_data_from_previous_read = b''
data_to_return = data
result.write(data_to_return)
bytes_read += len(data_to_return)
self._position += len(data_to_return)
result_data = result.getvalue()
if result_data and self._progress_callback:
self._bytes_read_since_last_progress_callback += len(result_data)
if (self._bytes_read_since_last_progress_callback >=
_PROGRESS_CALLBACK_THRESHOLD):
self._bytes_read_since_last_progress_callback = 0
self._progress_callback(self._position)
return result_data
def seek(self, offset, whence=os.SEEK_SET):
"""Seek to the given offset position.
Ideally, seek changes the stream position to the given byte offset.
But we only handle resumable retry for S3 to GCS transfers at this time,
which means, seek will be called only by the Apitools library.
Since Apitools calls seek only for limited cases, we avoid implementing
seek for all possible cases here in order to avoid unnecessary complexity
in the code.
Following are the cases where Apitools calls seek:
1) At the end of the transfer
https://github.com/google/apitools/blob/ca2094556531d61e741dc2954fdfccbc650cdc32/apitools/base/py/transfer.py#L986
to determine if it has read everything from the stream.
2) For any transient errors during uploads to seek back to a particular
position. This call is always made with whence == os.SEEK_SET.
Args:
offset (int): Defines the position realative to the `whence` where the
current position of the stream should be moved.
whence (int): The reference relative to which offset is interpreted.
Values for whence are: os.SEEK_SET or 0 - start of the stream
(thedefault). os.SEEK_END or 2 - end of the stream. We do not support
other os.SEEK_* constants.
Returns:
(int) The current position.
Raises:
Error:
If seek is called with whence == os.SEEK_END for offset not
equal to the last position.
If seek is called with whence == os.SEEK_CUR.
"""
if whence == os.SEEK_END:
if offset:
raise errors.Error(
'Non-zero offset from os.SEEK_END is not allowed.'
'Offset: {}.'.format(offset)
)
elif whence == os.SEEK_SET:
# Relative to the start of the stream, the offset should be the size
# of the stream
if offset != self._position:
self._restart_download(offset)
else:
raise errors.Error(
'Seek is only supported for os.SEEK_END and os.SEEK_SET.'
)
return self._position
  def seekable(self):
    """Returns True if the stream should be treated as a seekable stream.

    Seekability is fixed at construction time based on whether the upload
    destination supports seekable daisy-chain upload streams.
    """
    return self._seekable
  def tell(self):
    """Returns the current position (bytes read from the stream so far)."""
    return self._position
def close(self):
"""Updates progress callback if needed."""
if self._is_closed:
# Ensures that close called multiple times does not have any side-effect.
return
if (self._progress_callback and
(self._bytes_read_since_last_progress_callback or
# Update progress for zero-sized object.
self._end_position == 0)):
self._bytes_read_since_last_progress_callback = 0
self._progress_callback(self._position)
self._is_closed = True
class BufferController:
  """Manages a bidirectional buffer to read and write simultaneously.

  Used for daisy-chain copies: a download thread writes object bytes into the
  buffer while the caller's upload reads them back out, so the whole object
  never needs to be held in memory at once.

  Attributes:
    buffer_queue (collections.deque): The underlying queue that acts like a
      buffer for the streams
    buffer_condition (threading.Condition): The condition object used for
      waiting based on the underlying buffer_queue state.
      All threads waiting on this condition are notified when data is added or
      removed from buffer_queue. Streams that write to the buffer wait on this
      condition until the buffer has space, and streams that read from the
      buffer wait on this condition until the buffer has data.
    shutdown_event (threading.Event): Used for signaling the operations to
      terminate.
    writable_stream (_WritableStream): Stream that writes to the buffer.
    readable_stream (_ReadableStream): Stream that reads from the buffer.
    exception_raised (Exception): Stores the Exception instance responsible for
      termination of the operation.
  """

  def __init__(self, source_resource, destination_scheme,
               user_request_args=None,
               progress_callback=None):
    """Initializes BufferController.

    Args:
      source_resource (resource_reference.ObjectResource): Must
        contain the full object path of existing object.
      destination_scheme (storage_url.ProviderPrefix): The destination
        provider.
      user_request_args (UserRequestArgs|None): Values for RequestConfig.
      progress_callback (progress_callbacks.FilesAndBytesProgressCallback):
        Accepts processed bytes and submits progress info for aggregation.
    """
    self._source_resource = source_resource
    self._user_request_args = user_request_args
    self.buffer_queue = collections.deque()
    self.buffer_condition = threading.Condition()
    self.shutdown_event = threading.Event()
    self.writable_stream = _WritableStream(self.buffer_queue,
                                           self.buffer_condition,
                                           self.shutdown_event)
    destination_capabilities = api_factory.get_capabilities(destination_scheme)
    # The readable stream is seekable only when the destination API supports
    # seekable daisy-chain upload streams; otherwise seeking would require
    # restarting the download, which the upload path cannot tolerate.
    self.readable_stream = _ReadableStream(
        self.buffer_queue,
        self.buffer_condition,
        self.shutdown_event,
        self._source_resource.size,
        restart_download_callback=self.restart_download,
        progress_callback=progress_callback,
        seekable=(cloud_api.Capability.DAISY_CHAIN_SEEKABLE_UPLOAD_STREAM
                  in destination_capabilities))
    self._download_thread = None
    self.exception_raised = None

  def _get_source_user_request_args_for_download(self):
    """Returns a modified copy of user_request_args for the download request.

    When performing a daisy-chain copy (e.g., S3 to GCS, or GCS to S3,
    or GCS to GCS), certain flags like custom contexts are intended for the
    destination and are unsupported by the source.

    For example, object contexts are supported by GCS, but not by S3, so while
    performing a daisy-chain copy from S3 to GCS, the object contexts specified
    in the user_request_args (intended for the destination) should not be
    passed to create the request config for the source, as it would result in
    an error.

    This method creates a copy of the user_request_args and
    removes such destination-intended specific flags before initiating
    the download from the source.

    Returns:
      (UserRequestArgs|None) The original user_request_args when there is
      nothing to strip, otherwise a deep copy with the destination-only
      custom-context fields cleared.
    """
    if not self._user_request_args or not self._user_request_args.resource_args:
      return self._user_request_args

    user_args = copy.deepcopy(self._user_request_args)
    resource_args = user_args.resource_args
    # While doing daisy chain, these arguments are specified for the
    # destination resource, and not for the source resource. So, we need to
    # set them to None for S3.
    setattr(resource_args, 'custom_contexts_to_set', None)
    setattr(resource_args, 'custom_contexts_to_remove', None)
    setattr(resource_args, 'custom_contexts_to_update', None)
    return user_args

  def _run_download(self, start_byte):
    """Performs the download operation (run on the download thread).

    Args:
      start_byte (int): Byte offset in the source object to download from.
    """
    request_config = request_config_factory.get_request_config(
        self._source_resource.storage_url,
        user_request_args=self._get_source_user_request_args_for_download())
    client = api_factory.get_api(self._source_resource.storage_url.scheme)
    try:
      # Zero-sized objects have nothing to download; skipping the call avoids
      # an unnecessary API request.
      if self._source_resource.size != 0:
        client.download_object(
            self._source_resource,
            self.writable_stream,
            request_config,
            start_byte=start_byte,
            download_strategy=cloud_api.DownloadStrategy.ONE_SHOT)
    except _AbruptShutdownError:
      # Shutdown caused by interruption from another thread.
      pass
    except Exception as e:  # pylint: disable=broad-except
      # The stack trace of the exception raised in the thread is not visible
      # in the caller thread. Hence we catch any exception so that we can
      # re-raise them from the parent thread.
      self.shutdown(e)

  def start_download_thread(self, start_byte=0):
    """Starts the download on a separate thread, beginning at start_byte."""
    self._download_thread = threading.Thread(target=self._run_download,
                                             args=(start_byte,))
    self._download_thread.start()

  def wait_for_download_thread_to_terminate(self):
    """Blocks until the download thread exits, if one was started."""
    if self._download_thread is not None:
      self._download_thread.join()

  def restart_download(self, start_byte):
    """Restarts the download_thread.

    The ordering here matters: signal shutdown, wake any waiters so they can
    observe it, join the old thread, and only then reset state for the new
    thread.

    Args:
      start_byte (int): The start byte for the new download call.
    """
    # Signal the download to end.
    self.shutdown_event.set()
    with self.buffer_condition:
      # Wake any stream blocked on the buffer so it can see the shutdown
      # event and exit.
      self.buffer_condition.notify_all()
    self.wait_for_download_thread_to_terminate()

    # Clear all the data in the underlying buffer.
    self.buffer_queue.clear()
    # Reset the shutdown signal so the new download thread can run.
    self.shutdown_event.clear()
    self.start_download_thread(start_byte)

  def shutdown(self, error):
    """Sets the shutdown event and stores the error to re-raise later.

    Args:
      error (Exception): The error responsible for triggering shutdown.
    """
    self.shutdown_event.set()
    with self.buffer_condition:
      # Wake all waiting streams so they notice the shutdown event instead of
      # blocking forever.
      self.buffer_condition.notify_all()
    self.exception_raised = error
class DaisyChainCopyTask(copy_util.ObjectCopyTaskWithExitHandler):
  """Represents an operation to copy by downloading and uploading.

  This task downloads from one cloud location and uploads to another cloud
  location by keeping an in-memory buffer.
  """

  def __init__(
      self,
      source_resource,
      destination_resource,
      delete_source=False,
      fetch_source_fields_scope=None,
      posix_to_set=None,
      print_created_message=False,
      print_source_version=False,
      user_request_args=None,
      verbose=False,
  ):
    """Initializes task.

    Args:
      source_resource (resource_reference.ObjectResource): Must contain the
        full object path of existing object. Directories will not be accepted.
      destination_resource (resource_reference.UnknownResource): Must contain
        the full object path. Object may not exist yet. Existing objects at
        this location will be overwritten. Directories will not be accepted.
      delete_source (bool): If copy completes successfully, delete the source
        object afterwards.
      fetch_source_fields_scope (FieldsScope|None): If present, then refetch
        source_resource with metadata determined by this FieldsScope.
      posix_to_set (PosixAttributes|None): See parent class.
      print_created_message (bool): See parent class.
      print_source_version (bool): See parent class.
      user_request_args (UserRequestArgs|None): See parent class.
      verbose (bool): See parent class.

    Raises:
      errors.Error: If either endpoint is not a cloud URL (daisy chain only
        applies to cloud-to-cloud copies).
    """
    super(DaisyChainCopyTask, self).__init__(
        source_resource,
        destination_resource,
        posix_to_set=posix_to_set,
        print_created_message=print_created_message,
        print_source_version=print_source_version,
        user_request_args=user_request_args,
        verbose=verbose,
    )
    if (not isinstance(source_resource.storage_url, storage_url.CloudUrl)
        or not isinstance(destination_resource.storage_url,
                          storage_url.CloudUrl)):
      raise errors.Error(
          'DaisyChainCopyTask is for copies between cloud providers.'
      )

    self._fetch_source_fields_scope = fetch_source_fields_scope
    self._delete_source = delete_source

    # Keyed on the destination so two copies to the same URL don't run in
    # parallel.
    self.parallel_processing_key = (
        self._destination_resource.storage_url.url_string)

  def _get_md5_hash(self):
    """Returns the MD5 Hash if present and hash validation is requested."""
    if (properties.VALUES.storage.check_hashes.Get() ==
        properties.CheckHashes.NEVER.value):
      return None

    if self._enriched_source_resource.md5_hash is None:
      # For composite uploads, MD5 hash might be missing.
      # TODO(b/191975989) Add support for crc32c once -D option is implemented.
      # Composite uploads will have crc32c information, which we should
      # pass to the request.
      log.warning(
          'Found no hashes to validate object downloaded from %s and'
          ' uploaded to %s. Integrity cannot be assured without hashes.',
          self._enriched_source_resource, self._destination_resource)
    return self._enriched_source_resource.md5_hash

  def _gapfill_request_config_field(self, resource_args,
                                    request_config_field_name,
                                    source_resource_field_name):
    """Copies a source metadata value onto resource_args if not already set.

    Args:
      resource_args: The request config's resource args to mutate.
      request_config_field_name (str): Attribute name on resource_args.
      source_resource_field_name (str): Attribute name on the (enriched)
        source resource to copy the value from.
    """
    request_config_value = getattr(resource_args, request_config_field_name,
                                   None)
    if request_config_value is None:
      setattr(
          resource_args,
          request_config_field_name,
          getattr(
              self._enriched_source_resource, source_resource_field_name
          ),
      )

  def _populate_request_config_with_resource_values(self, request_config):
    """Fills request_config metadata gaps with values from the source object."""
    resource_args = request_config.resource_args
    # Does not cover all fields. Just the ones gsutil does.
    self._gapfill_request_config_field(resource_args, 'cache_control',
                                       'cache_control')
    self._gapfill_request_config_field(resource_args, 'content_disposition',
                                       'content_disposition')
    self._gapfill_request_config_field(resource_args, 'content_encoding',
                                       'content_encoding')
    self._gapfill_request_config_field(resource_args, 'content_language',
                                       'content_language')
    self._gapfill_request_config_field(resource_args, 'content_type',
                                       'content_type')
    self._gapfill_request_config_field(resource_args, 'custom_time',
                                       'custom_time')
    self._gapfill_request_config_field(resource_args, 'md5_hash',
                                       'md5_hash')
    # Storage class is intentionally excluded here, since gsutil uses the
    # bucket's default for daisy chain destinations:
    # https://github.com/GoogleCloudPlatform/gsutil/blob/db22c6cf44e4f58a56864f0a6f9bcdf868a3c156/gslib/utils/copy_helper.py#L3860

  def execute(self, task_status_queue=None):
    """Copies file by downloading and uploading in parallel."""
    # TODO (b/168712813): Add option to use the Data Transfer component.

    # We only preserve metadata for S3 to GCS syncs, and not for GCS to S3.
    # Note that GCS to GCS, and S3 to S3 rsync only follows intra-cloud
    # metadata preservation logic, and not daisy chain logic. Rsync does not
    # support --daisy-chain flag, so we don't need to worry about it here.
    # Additionally cp, mv doesn't require re-fetching source metadata, as this
    # is only required for rsync use cases due to attributes lost during the
    # comparison algorithm, so we don't need to worry about it here too.
    if self._fetch_source_fields_scope and (
        self._source_resource.storage_url.scheme
        is storage_url.ProviderPrefix.S3
    ):
      # Update source_resource with metadata if fetch_source_fields_scope.
      source_client = api_factory.get_api(
          self._source_resource.storage_url.scheme
      )
      self._enriched_source_resource = source_client.get_object_metadata(
          self._source_resource.bucket,
          self._source_resource.name,
          generation=self._source_resource.generation,
          fields_scope=self._fetch_source_fields_scope,
      )
    else:
      self._enriched_source_resource = self._source_resource

    destination_client = api_factory.get_api(
        self._destination_resource.storage_url.scheme)
    # Respect no-clobber: if the destination exists and must not be
    # overwritten, skip the copy (and record the skip in the manifest).
    if copy_util.check_for_cloud_clobber(self._user_request_args,
                                         destination_client,
                                         self._destination_resource):
      log.status.Print(
          copy_util.get_no_clobber_message(
              self._destination_resource.storage_url))
      if self._send_manifest_messages:
        manifest_util.send_skip_message(
            task_status_queue, self._enriched_source_resource,
            self._destination_resource,
            copy_util.get_no_clobber_message(
                self._destination_resource.storage_url))
      return

    progress_callback = progress_callbacks.FilesAndBytesProgressCallback(
        status_queue=task_status_queue,
        offset=0,
        length=self._enriched_source_resource.size,
        source_url=self._enriched_source_resource.storage_url,
        destination_url=self._destination_resource.storage_url,
        operation_name=task_status.OperationName.DAISY_CHAIN_COPYING,
        process_id=os.getpid(),
        thread_id=threading.get_ident(),
    )

    buffer_controller = BufferController(
        self._enriched_source_resource,
        self._destination_resource.storage_url.scheme,
        self._user_request_args,
        progress_callback)

    # Perform download in a separate thread so that upload can be performed
    # simultaneously.
    buffer_controller.start_download_thread()

    content_type = (
        self._enriched_source_resource.content_type or
        request_config_factory.DEFAULT_CONTENT_TYPE)
    request_config = request_config_factory.get_request_config(
        self._destination_resource.storage_url,
        content_type=content_type,
        md5_hash=self._get_md5_hash(),
        size=self._enriched_source_resource.size,
        user_request_args=self._user_request_args)
    # Request configs are designed to translate between providers.
    self._populate_request_config_with_resource_values(request_config)

    result_resource = None
    try:
      upload_strategy = upload_util.get_upload_strategy(
          api=destination_client,
          object_length=self._enriched_source_resource.size)
      result_resource = destination_client.upload_object(
          buffer_controller.readable_stream,
          self._destination_resource,
          request_config,
          posix_to_set=self._posix_to_set,
          source_resource=self._enriched_source_resource,
          upload_strategy=upload_strategy,
      )
    except _AbruptShutdownError:
      # Not raising daisy_chain_stream.exception_raised here because we want
      # to wait for the download thread to finish.
      pass
    except Exception as e:  # pylint: disable=broad-except
      # For all the other errors raised during upload, we want to make
      # sure that the download thread is terminated before we re-raise.
      # Hence we catch any exception and store it to be re-raised later.
      buffer_controller.shutdown(e)

    # Always join the download thread and close the stream before deciding
    # whether to raise, so no thread is left running.
    buffer_controller.wait_for_download_thread_to_terminate()
    buffer_controller.readable_stream.close()
    if buffer_controller.exception_raised:
      raise buffer_controller.exception_raised

    if result_resource:
      self._print_created_message_if_requested(result_resource)
      if self._send_manifest_messages:
        manifest_util.send_success_message(
            task_status_queue,
            self._enriched_source_resource,
            self._destination_resource,
            md5_hash=result_resource.md5_hash)

    if self._delete_source:
      return task.Output(
          additional_task_iterators=[[
              delete_task.DeleteObjectTask(
                  self._enriched_source_resource.storage_url
              )
          ]],
          messages=None,
      )

View File

@@ -0,0 +1,140 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deletes temporary components and tracker files from a composite upload."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import glob
import os
from googlecloudsdk.api_lib.storage import errors as api_errors
from googlecloudsdk.command_lib.storage import errors as command_errors
from googlecloudsdk.command_lib.storage import tracker_file_util
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.command_lib.storage.tasks.cp import copy_component_util
from googlecloudsdk.command_lib.storage.tasks.rm import delete_task
from googlecloudsdk.core import log
def _try_delete_and_return_permissions_error(component_url):
  """Attempts deleting component and returns any permissions errors.

  Args:
    component_url (StorageUrl): URL of the temporary component to delete.

  Returns:
    The api_errors.CloudApiError when deletion failed with HTTP 403,
    otherwise None.

  Raises:
    api_errors.CloudApiError: Re-raised for any non-403 API failure.
  """
  try:
    delete_task.DeleteObjectTask(component_url, verbose=False).execute()
  except api_errors.CloudApiError as e:
    # A 403 means this account cannot delete the component; hand the error
    # back so the caller can surface one consolidated message.
    if getattr(e, 'status_code', None) == 403:
      return e
    raise
class DeleteTemporaryComponentsTask(task.Task):
  """Deletes temporary components and tracker files after a composite upload."""

  def __init__(self, source_resource, destination_resource, random_prefix):
    """Initializes a task instance.

    Args:
      source_resource (resource_reference.FileObjectResource): The local,
        uploaded file.
      destination_resource (resource_reference.UnknownResource): The final
        composite object's metadata.
      random_prefix (str): ID added to temporary component names.
    """
    super(DeleteTemporaryComponentsTask, self).__init__()
    self._source_resource = source_resource
    self._destination_resource = destination_resource
    self._random_prefix = random_prefix

  def execute(self, task_status_queue=None):
    """Deletes temporary components and associated tracker files.

    Args:
      task_status_queue: See base class.

    Returns:
      A task.Output with tasks for deleting temporary components, or None if
      no completed components were found.

    Raises:
      command_errors.FatalError: If a component deletion failed with a
        permissions (403) error.
    """
    del task_status_queue
    # Build the tracker-file path for a component with an empty component_id;
    # this serves as a glob prefix matching every component's tracker file.
    component_tracker_path_prefix = tracker_file_util.get_tracker_file_path(
        copy_component_util.get_temporary_component_resource(
            self._source_resource, self._destination_resource,
            self._random_prefix, component_id='').storage_url,
        tracker_file_util.TrackerFileType.UPLOAD,
        # TODO(b/190093425): Setting component_number will not be necessary
        # after using the final destination to generate component tracker
        # paths.
        component_number='')
    # Matches all paths, regardless of component number:
    component_tracker_paths = glob.iglob(component_tracker_path_prefix + '*')

    component_urls = []
    found_permissions_error = permissions_error = None
    for component_tracker_path in component_tracker_paths:
      tracker_data = tracker_file_util.read_resumable_upload_tracker_file(
          component_tracker_path)
      # Only completed components exist in the bucket and need deletion.
      if tracker_data.complete:
        # The component number is the suffix after the last underscore in the
        # tracker file name.
        _, _, component_number = component_tracker_path.rpartition('_')
        component_url = (
            copy_component_util.get_temporary_component_resource(
                self._source_resource,
                self._destination_resource,
                self._random_prefix,
                component_id=component_number).storage_url)

        # Probe deletion once: the first component tells us whether this
        # account has delete permission. After that, no further delete
        # attempts are made in this loop.
        if found_permissions_error is None:
          permissions_error = _try_delete_and_return_permissions_error(
              component_url)
          found_permissions_error = permissions_error is not None

        if found_permissions_error:
          # Save URL for error message.
          component_urls.append(component_url)
        else:
          # Save URL to delete with task later.
          component_urls.append(component_url)
      # Tracker files are local and always safe to remove.
      os.remove(component_tracker_path)

    if permissions_error:
      log.error(
          'Parallel composite upload failed: Permissions error detected while'
          ' attempting to delete object component.'
          '\n\nTo disable parallel composite uploads, run:'
          '\ngcloud config set storage/parallel_composite_upload_enabled False'
          '\n\nTo delete the temporary objects left over by this command,'
          ' switch to an account with appropriate permissions and run:'
          '\ngcloud storage rm {}'.format(' '.join(
              [url.url_string for url in component_urls])))
      raise command_errors.FatalError(permissions_error)

    if component_urls:
      additional_task_iterators = [
          [
              delete_task.DeleteObjectTask(url, verbose=False)
              for url in component_urls
          ]
      ]
      return task.Output(
          additional_task_iterators=additional_task_iterators, messages=None)

  def __eq__(self, other):
    # NOTE(review): defining __eq__ without __hash__ makes instances
    # unhashable — confirm task instances are never used in sets/dict keys.
    if not isinstance(other, type(self)):
      return NotImplemented
    return (
        self._source_resource == other._source_resource
        and self._destination_resource == other._destination_resource
        and self._random_prefix == other._random_prefix
    )

View File

@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for performing download operation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import os
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import gzip_util
from googlecloudsdk.command_lib.storage import hash_util
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage import symlink_util
from googlecloudsdk.command_lib.storage import tracker_file_util
SYMLINK_TEMPORARY_PLACEHOLDER_SUFFIX = '_sym'
def _decompress_or_rename_file(
    source_resource,
    temporary_file_path,
    final_file_path,
    do_not_decompress_flag=False,
    server_encoding=None,
):
  """Converts temporary file to final form by decompressing or renaming.

  Args:
    source_resource (ObjectResource): May contain encoding metadata.
    temporary_file_path (str): File path to unzip or rename.
    final_file_path (str): File path to write final file to.
    do_not_decompress_flag (bool): User flag that blocks decompression.
    server_encoding (str|None): Server-reported `content-encoding` of file.

  Returns:
    (bool) True if file was decompressed or renamed, and
    False if file did not exist.
  """
  if not os.path.exists(temporary_file_path):
    return False

  was_decompressed = gzip_util.decompress_gzip_if_necessary(
      source_resource,
      temporary_file_path,
      final_file_path,
      do_not_decompress_flag,
      server_encoding,
  )
  if was_decompressed:
    # Decompression wrote a new file at final_file_path; the compressed
    # temporary copy is no longer needed.
    os.remove(temporary_file_path)
  else:
    os.rename(temporary_file_path, final_file_path)
  return True
def finalize_download(
    source_resource,
    temporary_file_path,
    final_file_path,
    do_not_decompress_flag=False,
    server_encoding=None,
    convert_symlinks=False,
):
  """Converts temporary file to final form.

  This may involve decompressing, renaming, and/or converting symlink
  placeholders to actual symlinks.

  Args:
    source_resource (ObjectResource): May contain encoding metadata.
    temporary_file_path (str): File path to unzip or rename.
    final_file_path (str): File path to write final file to.
    do_not_decompress_flag (bool): User flag that blocks decompression.
    server_encoding (str|None): Server-reported `content-encoding` of file.
    convert_symlinks (bool): Whether symlink placeholders should be converted
      to actual symlinks.

  Returns:
    (bool) True if file was decompressed, renamed, and/or converted to a
    symlink; False if file did not exist.
  """
  should_make_symlink = convert_symlinks and source_resource.is_symlink
  if should_make_symlink:
    # The decompressed/renamed content is symlink placeholder text, so stage
    # it at a temporary path; the actual symlink is created afterwards at
    # final_file_path.
    intermediate_path = (
        temporary_file_path + SYMLINK_TEMPORARY_PLACEHOLDER_SUFFIX)
  else:
    intermediate_path = final_file_path

  if not _decompress_or_rename_file(
      source_resource=source_resource,
      temporary_file_path=temporary_file_path,
      final_file_path=intermediate_path,
      do_not_decompress_flag=do_not_decompress_flag,
      server_encoding=server_encoding,
  ):
    # Temporary file did not exist; nothing to finalize.
    return False

  if should_make_symlink:
    symlink_util.create_symlink_from_temporary_placeholder(
        placeholder_path=intermediate_path, symlink_path=final_file_path
    )
    os.remove(intermediate_path)
  return True
def validate_download_hash_and_delete_corrupt_files(download_path, source_hash,
                                                    destination_hash):
  """Confirms hashes match for copied objects.

  Args:
    download_path (str): URL of object being validated.
    source_hash (str): Hash of source object.
    destination_hash (str): Hash of downloaded object.

  Raises:
    HashMismatchError: Hashes are not equal.
  """
  try:
    hash_util.validate_object_hashes_match(
        download_path, source_hash, destination_hash)
  except errors.HashMismatchError:
    # The download is corrupt: remove the bad file and its tracker state so a
    # future attempt starts from scratch, then surface the mismatch.
    os.remove(download_path)
    tracker_file_util.delete_download_tracker_files(
        storage_url.storage_url_from_string(download_path))
    raise
def return_and_report_if_nothing_to_download(cloud_resource, progress_callback):
  """Returns valid download range bool and reports progress if not."""
  if cloud_resource.size != 0:
    return False
  # Zero-byte object: there is nothing to transfer, but report zero progress
  # so status aggregation still records the file.
  if progress_callback:
    progress_callback(0)
  return True

View File

@@ -0,0 +1,326 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for file downloads.
Typically executed in a task iterator:
googlecloudsdk.command_lib.storage.tasks.task_executor.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import copy
import os
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import cloud_api
from googlecloudsdk.command_lib.storage import fast_crc32c_util
from googlecloudsdk.command_lib.storage import manifest_util
from googlecloudsdk.command_lib.storage import posix_util
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage import symlink_util
from googlecloudsdk.command_lib.storage import tracker_file_util
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.command_lib.storage.tasks import task_util
from googlecloudsdk.command_lib.storage.tasks.cp import copy_component_util
from googlecloudsdk.command_lib.storage.tasks.cp import copy_util
from googlecloudsdk.command_lib.storage.tasks.cp import download_util
from googlecloudsdk.command_lib.storage.tasks.cp import file_part_download_task
from googlecloudsdk.command_lib.storage.tasks.cp import finalize_sliced_download_task
from googlecloudsdk.command_lib.storage.tasks.rm import delete_task
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import scaled_integer
def _should_perform_sliced_download(source_resource, destination_resource):
  """Returns True if conditions are right for a sliced download."""
  if destination_resource.storage_url.is_stream:
    # Component downloads write at different offsets, which streams can't do.
    return False

  hash_check_setting = properties.VALUES.storage.check_hashes.Get()
  if (not source_resource.crc32c_hash
      and hash_check_setting != properties.CheckHashes.NEVER.value):
    # Without a CRC32C hash, a sliced download cannot be validated.
    return False

  threshold = scaled_integer.ParseInteger(
      properties.VALUES.storage.sliced_object_download_threshold.Get()
  )
  component_size = scaled_integer.ParseInteger(
      properties.VALUES.storage.sliced_object_download_component_size.Get()
  )

  capability_args = [source_resource.storage_url.scheme]
  if properties.VALUES.storage.enable_zonal_buckets_bidi_streaming.GetBool():
    capability_args.append(source_resource.storage_url.bucket_name)
  api_capabilities = api_factory.get_capabilities(*capability_args)

  # Slice only when the object is large enough, slicing is enabled
  # (non-zero threshold and component size), the API supports it, and
  # parallelism is on.
  return (
      source_resource.size
      and threshold != 0
      and source_resource.size > threshold
      and component_size
      and cloud_api.Capability.SLICED_DOWNLOAD in api_capabilities
      and task_util.should_use_parallelism()
  )
class FileDownloadTask(copy_util.ObjectCopyTaskWithExitHandler):
"""Represents a command operation triggering a file download."""
def __init__(
self,
source_resource,
destination_resource,
delete_source=False,
do_not_decompress=False,
posix_to_set=None,
print_created_message=False,
print_source_version=False,
system_posix_data=None,
user_request_args=None,
verbose=False,
):
"""Initializes task.
Args:
source_resource (ObjectResource): Must contain the full path of object to
download, including bucket. Directories will not be accepted. Does not
need to contain metadata.
destination_resource (FileObjectResource|UnknownResource): Must contain
local filesystem path to destination object. Does not need to contain
metadata.
delete_source (bool): If copy completes successfully, delete the source
object afterwards.
do_not_decompress (bool): Prevents automatically decompressing downloaded
gzips.
posix_to_set (PosixAttributes|None): See parent class.
print_created_message (bool): See parent class.
print_source_version (bool): See parent class.
system_posix_data (SystemPosixData): System-wide POSIX info.
user_request_args (UserRequestArgs|None): See parent class..
verbose (bool): See parent class.
"""
super(FileDownloadTask, self).__init__(
source_resource,
destination_resource,
posix_to_set=posix_to_set,
print_created_message=print_created_message,
print_source_version=print_source_version,
user_request_args=user_request_args,
verbose=verbose,
)
self._delete_source = delete_source
self._do_not_decompress = do_not_decompress
self._system_posix_data = system_posix_data
self._temporary_destination_resource = (
self._get_temporary_destination_resource())
if (self._source_resource.size and
self._source_resource.size >= scaled_integer.ParseInteger(
properties.VALUES.storage.resumable_threshold.Get())):
self._strategy = cloud_api.DownloadStrategy.RESUMABLE
else:
self._strategy = cloud_api.DownloadStrategy.RETRIABLE_IN_FLIGHT
self.parallel_processing_key = (
self._destination_resource.storage_url.url_string)
def _get_temporary_destination_resource(self):
temporary_resource = copy.deepcopy(self._destination_resource)
temporary_resource.storage_url.resource_name += (
storage_url.TEMPORARY_FILE_SUFFIX)
return temporary_resource
def _get_sliced_download_tasks(self):
"""Creates all tasks necessary for a sliced download."""
component_offsets_and_lengths = (
copy_component_util.get_component_offsets_and_lengths(
self._source_resource.size,
copy_component_util.get_component_count(
self._source_resource.size,
properties.VALUES.storage.sliced_object_download_component_size
.Get(),
properties.VALUES.storage.sliced_object_download_max_components
.GetInt())))
download_component_task_list = []
for i, (offset, length) in enumerate(component_offsets_and_lengths):
download_component_task_list.append(
file_part_download_task.FilePartDownloadTask(
self._source_resource,
self._temporary_destination_resource,
offset=offset,
length=length,
component_number=i,
total_components=len(component_offsets_and_lengths),
strategy=self._strategy,
user_request_args=self._user_request_args))
finalize_sliced_download_task_list = [
finalize_sliced_download_task.FinalizeSlicedDownloadTask(
self._source_resource,
self._temporary_destination_resource,
self._destination_resource,
delete_source=self._delete_source,
do_not_decompress=self._do_not_decompress,
posix_to_set=self._posix_to_set,
system_posix_data=self._system_posix_data,
user_request_args=self._user_request_args,
)
]
return (download_component_task_list, finalize_sliced_download_task_list)
def _restart_download(self):
log.status.Print('Temporary download file corrupt.'
' Restarting download {}'.format(self._source_resource))
temporary_download_url = self._temporary_destination_resource.storage_url
os.remove(temporary_download_url.resource_name)
tracker_file_util.delete_download_tracker_files(temporary_download_url)
  def execute(self, task_status_queue=None):
    """Creates appropriate download tasks.

    Args:
      task_status_queue: Sink for progress and manifest messages. May be None.

    Returns:
      task.Output|None: Follow-up task iterators (component + finalize tasks
        for sliced downloads, or a delete task when moving); None otherwise.
    """
    # Validate POSIX settings up front so we fail before transferring bytes.
    posix_util.run_if_setting_posix(
        self._posix_to_set,
        self._user_request_args,
        posix_util.raise_if_invalid_file_permissions,
        self._system_posix_data,
        self._source_resource,
        known_posix=self._posix_to_set,
    )
    destination_url = self._destination_resource.storage_url
    # We need to call os.remove here for two reasons:
    # 1. It saves on disk space during a transfer.
    # 2. Os.rename fails if a file exists at the destination. Avoiding this by
    # removing files after a download makes us susceptible to a race condition
    # between two running instances of gcloud storage. See the following PR for
    # more information: https://github.com/GoogleCloudPlatform/gsutil/pull/1202.
    # Note that it's not enough to check the results of `exists()`, since that
    # method returns False if the path points to a broken symlink.
    is_destination_symlink = os.path.islink(destination_url.resource_name)
    if is_destination_symlink or destination_url.exists():
      if self._user_request_args and self._user_request_args.no_clobber:
        log.status.Print(copy_util.get_no_clobber_message(destination_url))
        if self._send_manifest_messages:
          manifest_util.send_skip_message(
              task_status_queue, self._source_resource,
              self._destination_resource,
              copy_util.get_no_clobber_message(destination_url))
        return
      os.remove(destination_url.resource_name)
    temporary_download_file_exists = (
        self._temporary_destination_resource.storage_url.exists())
    # A temporary file larger than the source cannot be a valid partial
    # download, so start over.
    if temporary_download_file_exists and os.path.getsize(
        self._temporary_destination_resource.storage_url.resource_name
    ) > self._source_resource.size:
      self._restart_download()
    if _should_perform_sliced_download(self._source_resource,
                                       self._destination_resource):
      fast_crc32c_util.log_or_raise_crc32c_issues()
      download_component_task_list, finalize_sliced_download_task_list = (
          self._get_sliced_download_tasks())
      _, found_tracker_file = (
          tracker_file_util.read_or_create_download_tracker_file(
              self._source_resource,
              self._temporary_destination_resource.storage_url,
              total_components=len(download_component_task_list),
          ))
      if found_tracker_file:
        log.debug('Resuming sliced download with {} components.'.format(
            len(download_component_task_list)))
      else:
        if temporary_download_file_exists:
          # Component count may have changed, invalidating earlier download.
          self._restart_download()
        log.debug('Launching sliced download with {} components.'.format(
            len(download_component_task_list)))
      # Pre-create the temporary file so components can write at their offsets.
      copy_component_util.create_file_if_needed(
          self._source_resource, self._temporary_destination_resource)
      # Component tasks run first; the finalize task list runs afterwards.
      return task.Output(
          additional_task_iterators=[
              download_component_task_list,
              finalize_sliced_download_task_list,
          ],
          messages=None)
    # Non-sliced path: download the whole object as one part into the
    # temporary file, then finalize in-process below.
    part_download_task_output = file_part_download_task.FilePartDownloadTask(
        self._source_resource,
        self._temporary_destination_resource,
        offset=0,
        length=self._source_resource.size,
        do_not_decompress=self._do_not_decompress,
        strategy=self._strategy,
        user_request_args=self._user_request_args,
    ).execute(task_status_queue=task_status_queue)
    temporary_file_url = self._temporary_destination_resource.storage_url
    server_encoding = task_util.get_first_matching_message_payload(
        part_download_task_output.messages, task.Topic.API_DOWNLOAD_RESULT
    )
    preserve_symlinks = symlink_util.get_preserve_symlink_from_user_request(
        self._user_request_args
    )
    # Moves the temporary file to the real destination (handling
    # decompression and symlink conversion per the flags above).
    download_util.finalize_download(
        self._source_resource,
        temporary_file_url.resource_name,
        destination_url.resource_name,
        convert_symlinks=preserve_symlinks,
        do_not_decompress_flag=self._do_not_decompress,
        server_encoding=server_encoding,
    )
    # For sliced download, cleanup is done in the finalized sliced download task
    # We perform cleanup here for all other types in case some corrupt files
    # were left behind.
    tracker_file_util.delete_download_tracker_files(temporary_file_url)
    posix_util.run_if_setting_posix(
        self._posix_to_set,
        self._user_request_args,
        posix_util.set_posix_attributes_on_file_if_valid,
        self._system_posix_data,
        self._source_resource,
        self._destination_resource,
        known_source_posix=self._posix_to_set,
        preserve_symlinks=preserve_symlinks,
    )
    self._print_created_message_if_requested(self._destination_resource)
    if self._send_manifest_messages:
      manifest_util.send_success_message(
          task_status_queue,
          self._source_resource,
          self._destination_resource,
          md5_hash=task_util.get_first_matching_message_payload(
              part_download_task_output.messages, task.Topic.MD5))
    if self._delete_source:
      # Deleting the source turns the copy into a move; run as follow-up task.
      return task.Output(
          additional_task_iterators=[[
              delete_task.DeleteObjectTask(self._source_resource.storage_url),
          ]],
          messages=None,
      )

View File

@@ -0,0 +1,427 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for file downloads.
Typically executed in a task iterator:
googlecloudsdk.command_lib.storage.tasks.task_executor.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import os
import threading
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import cloud_api
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import fast_crc32c_util
from googlecloudsdk.command_lib.storage import hash_util
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage import tracker_file_util
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.command_lib.storage.tasks import task_status
from googlecloudsdk.command_lib.storage.tasks.cp import copy_component_util
from googlecloudsdk.command_lib.storage.tasks.cp import download_util
from googlecloudsdk.command_lib.storage.tasks.cp import file_part_task
from googlecloudsdk.command_lib.util import crc32c
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import files
from googlecloudsdk.core.util import hashing
_READ_SIZE = 8192  # 8 KiB chunk size used when scanning partial downloads.
# Unwritten regions of a pre-allocated download file read back as null bytes,
# so the first null byte marks where a partial download stopped.
NULL_BYTE = b'\x00'
def _get_first_null_byte_index(destination_url, offset, length):
  """Checks to see how many bytes in range have already been downloaded.

  Args:
    destination_url (storage_url.FileUrl): Has path of file being downloaded.
    offset (int): For components, index to start reading bytes at.
    length (int): For components, where to stop reading bytes.

  Returns:
    Int byte count of size of partially-downloaded file. Returns 0 if file is
    an invalid size, empty, or non-existent.
  """
  if not destination_url.exists():
    return 0

  # Component is slice of larger file. Find how much of slice is downloaded.
  position = offset
  limit = offset + length
  with files.BinaryFileReader(destination_url.resource_name) as reader:
    reader.seek(offset)
    while position < limit:
      chunk = reader.read(_READ_SIZE)
      if not chunk:
        # Hit EOF before finding a null byte.
        break
      null_index = chunk.find(NULL_BYTE)
      if null_index != -1:
        position += null_index
        break
      position += len(chunk)
  return position
def _get_digesters(component_number, resource):
  """Returns digesters dictionary for download hash validation.

  Note: The digester object is not picklable. It cannot be passed between
  tasks through the task graph.

  Args:
    component_number (int|None): Used to determine if downloading a slice in a
      sliced download, which uses CRC32C for hashing.
    resource (resource_reference.ObjectResource): For checking if object has
      known hash to validate against.

  Returns:
    Digesters dict.

  Raises:
    errors.Error: gcloud storage set to fail if performance-optimized digesters
      could not be created.
  """
  check_hashes = properties.VALUES.storage.check_hashes.Get()
  if check_hashes == properties.CheckHashes.NEVER.value:
    # Hash validation is disabled; nothing to compute.
    return {}

  digesters = {}
  if component_number is None and resource.md5_hash:
    # Whole-object downloads validate against the object's MD5.
    digesters[hash_util.HashAlgorithm.MD5] = hashing.get_md5()
  elif resource.crc32c_hash and (
      check_hashes == properties.CheckHashes.ALWAYS.value
      or fast_crc32c_util.check_if_will_use_fast_crc32c(install_if_missing=True)
  ):
    # Sliced components (and MD5-less objects) validate with CRC32C.
    digesters[hash_util.HashAlgorithm.CRC32C] = fast_crc32c_util.get_crc32c()

  if not digesters:
    log.warning(
        'Found no hashes to validate download of object: %s. Component number:'
        ' %s. Integrity cannot be assured without hashes.',
        resource,
        component_number,
    )
  return digesters
class FilePartDownloadTask(file_part_task.FilePartTask):
  """Downloads a byte range."""

  def __init__(self,
               source_resource,
               destination_resource,
               offset,
               length,
               component_number=None,
               total_components=None,
               do_not_decompress=False,
               strategy=cloud_api.DownloadStrategy.RETRIABLE_IN_FLIGHT,
               user_request_args=None):
    """Initializes task.

    Args:
      source_resource (resource_reference.ObjectResource): Must contain the full
        path of object to download, including bucket. Directories will not be
        accepted. Does not need to contain metadata.
      destination_resource (resource_reference.FileObjectResource): Must contain
        local filesystem path to the destination file. Does not need to contain
        metadata.
      offset (int): The index of the first byte in the download range.
      length (int): The number of bytes in the download range.
      component_number (int|None): If a multipart operation, indicates the
        component number.
      total_components (int|None): If a multipart operation, indicates the total
        number of components.
      do_not_decompress (bool): Prevents automatically decompressing
        downloaded gzips.
      strategy (cloud_api.DownloadStrategy): Determines what download
        implementation to use.
      user_request_args (UserRequestArgs|None): Values for RequestConfig.
    """
    super(FilePartDownloadTask,
          self).__init__(source_resource, destination_resource, offset, length,
                         component_number, total_components)
    self._do_not_decompress_flag = do_not_decompress
    self._strategy = strategy
    self._user_request_args = user_request_args

  def _calculate_deferred_hashes(self, digesters):
    """DeferredCrc32c does not hash on-the-fly and needs a summation call."""
    if isinstance(
        digesters.get(hash_util.HashAlgorithm.CRC32C),
        fast_crc32c_util.DeferredCrc32c,
    ):
      # Hash the downloaded byte range from disk in one pass.
      digesters[hash_util.HashAlgorithm.CRC32C].sum_file(
          self._destination_resource.storage_url.resource_name,
          self._offset,
          self._length,
      )

  def _disable_in_flight_decompression(self, is_resumable_or_sliced_download):
    """Whether or not to disable on-the-fly decompression.

    Args:
      is_resumable_or_sliced_download (bool): True for download strategies
        where the on-disk byte count must match the cloud object's size.

    Returns:
      bool: True if the API client should skip decompressing in flight.
    """
    if self._do_not_decompress_flag:
      # Respect user preference.
      return True
    if not is_resumable_or_sliced_download:
      # If we don't decompress in-flight, we'll do it later on the disk, which
      # is probably slower. However, the requests library might add the
      # "accept-encoding: gzip" header anyways.
      return False
    # Decompressing in flight changes file size, making resumable and sliced
    # downloads impossible.
    return bool(self._source_resource.content_encoding and
                'gzip' in self._source_resource.content_encoding)

  def _perform_download(self, request_config, progress_callback,
                        do_not_decompress, download_strategy, start_byte,
                        end_byte, write_mode, digesters):
    """Prepares file stream, calls API, and validates hash."""
    with files.BinaryFileWriter(
        self._destination_resource.storage_url.resource_name,
        create_path=True,
        mode=write_mode,
        convert_invalid_windows_characters=(
            properties.VALUES.storage
            .convert_incompatible_windows_path_characters.GetBool()
        )) as download_stream:
      download_stream.seek(start_byte)
      provider = self._source_resource.storage_url.scheme
      enable_zonal_buckets_bidi_streaming = (
          properties.VALUES.storage.enable_zonal_buckets_bidi_streaming.GetBool()
      )
      if enable_zonal_buckets_bidi_streaming:
        # Bidi streaming clients are constructed per-bucket.
        bucket_name = self._source_resource.storage_url.bucket_name
        api = api_factory.get_api(provider, bucket_name=bucket_name)
      else:
        api = api_factory.get_api(provider)
      # TODO(b/162264437): Support all of download_object's parameters.
      api_download_result = api.download_object(
          self._source_resource,
          download_stream,
          request_config,
          digesters=digesters,
          do_not_decompress=do_not_decompress,
          download_strategy=download_strategy,
          progress_callback=progress_callback,
          start_byte=start_byte,
          end_byte=end_byte)

    # Hash validation runs after the `with` block so the file is flushed and
    # closed before it is read back for deferred hashing.
    self._calculate_deferred_hashes(digesters)
    if hash_util.HashAlgorithm.MD5 in digesters:
      calculated_digest = hash_util.get_base64_hash_digest_string(
          digesters[hash_util.HashAlgorithm.MD5])
      download_util.validate_download_hash_and_delete_corrupt_files(
          self._destination_resource.storage_url.resource_name,
          self._source_resource.md5_hash, calculated_digest)
    elif hash_util.HashAlgorithm.CRC32C in digesters:
      # Only for one-shot composite object downloads as final CRC32C validated
      # in FinalizeSlicedDownloadTask.
      if self._component_number is None:
        calculated_digest = crc32c.get_hash(
            digesters[hash_util.HashAlgorithm.CRC32C])
        download_util.validate_download_hash_and_delete_corrupt_files(
            self._destination_resource.storage_url.resource_name,
            self._source_resource.crc32c_hash, calculated_digest)

    return api_download_result

  def _perform_retriable_download(self, request_config, progress_callback,
                                  digesters):
    """Sets up a basic download based on task attributes."""
    start_byte = self._offset
    end_byte = self._offset + self._length - 1

    return self._perform_download(
        request_config, progress_callback,
        self._disable_in_flight_decompression(False),
        cloud_api.DownloadStrategy.RETRIABLE_IN_FLIGHT, start_byte, end_byte,
        files.BinaryFileWriterMode.TRUNCATE, digesters)

  def _catch_up_digesters(self, digesters, start_byte, end_byte):
    """Gets hash of partially-downloaded file as start for validation."""
    for hash_algorithm in digesters:
      if isinstance(digesters[hash_algorithm], fast_crc32c_util.DeferredCrc32c):
        # Deferred calculation runs at end, no on-the-fly.
        continue
      digesters[hash_algorithm] = hash_util.get_hash_from_file(
          self._destination_resource.storage_url.resource_name,
          hash_algorithm,
          start=start_byte,
          stop=end_byte,
      )

  def _perform_resumable_download(self, request_config, progress_callback,
                                  digesters):
    """Resumes or starts a download that can be resumed."""
    copy_component_util.create_file_if_needed(self._source_resource,
                                              self._destination_resource)

    destination_url = self._destination_resource.storage_url
    first_null_byte = _get_first_null_byte_index(destination_url,
                                                 self._offset, self._length)
    _, found_tracker_file = (
        tracker_file_util.read_or_create_download_tracker_file(
            self._source_resource, destination_url))
    # Without a tracker file, any on-disk bytes are untrusted; start over.
    start_byte = first_null_byte if found_tracker_file else 0
    end_byte = self._source_resource.size - 1

    if start_byte:
      write_mode = files.BinaryFileWriterMode.MODIFY
      # Seed digesters with the bytes already on disk.
      self._catch_up_digesters(digesters, start_byte=0, end_byte=start_byte)
      log.status.Print('Resuming download for {}'.format(self._source_resource))
    else:
      # TRUNCATE can create new file unlike MODIFY.
      write_mode = files.BinaryFileWriterMode.TRUNCATE

    return self._perform_download(request_config, progress_callback,
                                  self._disable_in_flight_decompression(True),
                                  cloud_api.DownloadStrategy.RESUMABLE,
                                  start_byte, end_byte, write_mode, digesters)

  def _get_output(self, digesters, server_encoding):
    """Generates task.Output from download execution results.

    Args:
      digesters (dict): Contains hash objects for download checksums.
      server_encoding (str|None): Generic information from API client about the
        download results.

    Returns:
      task.Output: Data the parent download or finalize download class would
        like to have.
    """
    messages = []
    if hash_util.HashAlgorithm.MD5 in digesters:
      md5_digest = hash_util.get_base64_hash_digest_string(
          digesters[hash_util.HashAlgorithm.MD5])
      messages.append(task.Message(topic=task.Topic.MD5, payload=md5_digest))
    if hash_util.HashAlgorithm.CRC32C in digesters:
      # Component checksums are combined later, so include range metadata.
      crc32c_checksum = crc32c.get_checksum(
          digesters[hash_util.HashAlgorithm.CRC32C])
      messages.append(
          task.Message(
              topic=task.Topic.CRC32C,
              payload={
                  'component_number': self._component_number,
                  'crc32c_checksum': crc32c_checksum,
                  'length': self._length,
              }))
    if server_encoding:
      messages.append(
          task.Message(
              topic=task.Topic.API_DOWNLOAD_RESULT, payload=server_encoding
          )
      )
    return task.Output(additional_task_iterators=None, messages=messages)

  def _perform_component_download(self, request_config, progress_callback,
                                  digesters):
    """Component download does not validate hash or delete tracker."""
    destination_url = self._destination_resource.storage_url
    end_byte = self._offset + self._length - 1

    if self._strategy == cloud_api.DownloadStrategy.RESUMABLE:
      _, found_tracker_file = (
          tracker_file_util.read_or_create_download_tracker_file(
              self._source_resource,
              destination_url,
              slice_start_byte=self._offset,
              component_number=self._component_number))
      first_null_byte = _get_first_null_byte_index(
          destination_url, offset=self._offset, length=self._length)
      start_byte = first_null_byte if found_tracker_file else self._offset

      if start_byte > end_byte:
        # The whole slice is already on disk; just compute its hashes.
        log.status.Print('{} component {} already downloaded.'.format(
            self._source_resource, self._component_number))
        self._calculate_deferred_hashes(digesters)
        self._catch_up_digesters(
            digesters,
            start_byte=self._offset,
            end_byte=self._source_resource.size)
        return
      if found_tracker_file and start_byte != self._offset:
        self._catch_up_digesters(
            digesters, start_byte=self._offset, end_byte=start_byte)
        log.status.Print('Resuming download for {} component {}'.format(
            self._source_resource, self._component_number))
    else:
      # For non-resumable sliced downloads.
      start_byte = self._offset

    return self._perform_download(request_config, progress_callback,
                                  self._disable_in_flight_decompression(True),
                                  self._strategy, start_byte, end_byte,
                                  files.BinaryFileWriterMode.MODIFY, digesters)

  def execute(self, task_status_queue=None):
    """Performs download."""
    digesters = _get_digesters(self._component_number, self._source_resource)

    progress_callback = progress_callbacks.FilesAndBytesProgressCallback(
        status_queue=task_status_queue,
        offset=self._offset,
        length=self._length,
        source_url=self._source_resource.storage_url,
        destination_url=self._destination_resource.storage_url,
        component_number=self._component_number,
        total_components=self._total_components,
        operation_name=task_status.OperationName.DOWNLOADING,
        process_id=os.getpid(),
        thread_id=threading.get_ident(),
    )

    request_config = request_config_factory.get_request_config(
        self._source_resource.storage_url,
        decryption_key_hash_sha256=(
            self._source_resource.decryption_key_hash_sha256),
        user_request_args=self._user_request_args,
    )

    if self._source_resource.size and self._component_number is not None:
      try:
        server_encoding = self._perform_component_download(
            request_config, progress_callback, digesters
        )
      # pylint:disable=broad-except
      except Exception as e:
        # Component failures are reported as messages rather than raised so
        # sibling components can finish.
        if task_status_queue is not None:
          progress_callback(self._offset, error_occurred=True)
      # pylint:enable=broad-except
        return task.Output(
            additional_task_iterators=None,
            messages=[task.Message(topic=task.Topic.ERROR, payload=e)])
    elif self._strategy is cloud_api.DownloadStrategy.RESUMABLE:
      server_encoding = self._perform_resumable_download(
          request_config, progress_callback, digesters
      )
    else:
      server_encoding = self._perform_retriable_download(
          request_config, progress_callback, digesters
      )
    return self._get_output(digesters, server_encoding)

View File

@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract task for handling components, slices, or parts of larger files.
Typically executed in a task iterator:
googlecloudsdk.command_lib.storage.tasks.task_executor.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import abc
from googlecloudsdk.command_lib.storage.tasks import task
class FilePartTask(task.Task):
  """Abstract task operating on a contiguous byte range of a file."""

  def __init__(self, source_resource, destination_resource, offset, length,
               component_number=None, total_components=None):
    """Initializes task.

    Args:
      source_resource (resource_reference.Resource): Resource to copy from.
      destination_resource (resource_reference.Resource): Resource to copy to.
      offset (int): Index of the first byte in the range.
      length (int): Number of bytes in the range.
      component_number (int): Component index when this range is one part of
        a multipart operation.
      total_components (int): Total number of parts when this range is one
        part of a multipart operation.
    """
    super(FilePartTask, self).__init__()
    self._destination_resource = destination_resource
    self._source_resource = source_resource
    self._length = length
    self._offset = offset
    self._component_number = component_number
    self._total_components = total_components

  @abc.abstractmethod
  def execute(self, task_status_queue=None):
    pass

  def __eq__(self, other):
    """Tasks are equal when every range and resource attribute matches."""
    if not isinstance(other, FilePartTask):
      return NotImplemented
    return (self._destination_resource == other._destination_resource and
            self._source_resource == other._source_resource and
            self._offset == other._offset and
            self._length == other._length and
            self._component_number == other._component_number and
            self._total_components == other._total_components)

View File

@@ -0,0 +1,339 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for file uploads.
Typically executed in a task iterator:
googlecloudsdk.command_lib.storage.tasks.task_executor.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import collections
import functools
import os
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import cloud_api
from googlecloudsdk.api_lib.storage import errors as api_errors
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import encryption_util
from googlecloudsdk.command_lib.storage import errors as command_errors
from googlecloudsdk.command_lib.storage import fast_crc32c_util
from googlecloudsdk.command_lib.storage import hash_util
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage import tracker_file_util
from googlecloudsdk.command_lib.storage.resources import resource_reference
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.command_lib.storage.tasks.cp import file_part_task
from googlecloudsdk.command_lib.storage.tasks.cp import upload_util
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import retry
# Pairs a finished component upload's index with the resulting object
# metadata so downstream tasks can identify and order components.
UploadedComponent = collections.namedtuple(
    'UploadedComponent', ['component_number', 'object_resource'])
class FilePartUploadTask(file_part_task.FilePartTask):
"""Uploads a range of bytes from a file."""
def __init__(
self,
source_resource,
destination_resource,
source_path,
offset,
length,
component_number=None,
posix_to_set=None,
total_components=None,
user_request_args=None,
):
"""Initializes task.
Args:
source_resource (resource_reference.FileObjectResource): Must contain
local filesystem path to upload object. Does not need to contain
metadata.
destination_resource (resource_reference.ObjectResource|UnknownResource):
Must contain the full object path. Directories will not be accepted.
Existing objects at the this location will be overwritten.
source_path (str): Path to file to upload. May be the original or a
transformed temporary file.
offset (int): The index of the first byte in the upload range.
length (int): The number of bytes in the upload range.
component_number (int|None): If a multipart operation, indicates the
component number.
posix_to_set (PosixAttributes|None): POSIX info set as custom cloud
metadata on target. If provided and preserving POSIX, skip re-parsing
from file system.
total_components (int|None): If a multipart operation, indicates the total
number of components.
user_request_args (UserRequestArgs|None): Values for RequestConfig.
"""
super(FilePartUploadTask,
self).__init__(source_resource, destination_resource, offset, length,
component_number, total_components)
self._source_path = source_path
self._posix_to_set = posix_to_set
self._user_request_args = user_request_args
self._transformed_source_resource = resource_reference.FileObjectResource(
storage_url.storage_url_from_string(self._source_path))
def _get_output(self, destination_resource):
messages = []
if self._component_number is not None:
messages.append(
task.Message(
topic=task.Topic.UPLOADED_COMPONENT,
payload=UploadedComponent(
component_number=self._component_number,
object_resource=destination_resource)))
else:
messages.append(
task.Message(
topic=task.Topic.CREATED_RESOURCE, payload=destination_resource))
return task.Output(additional_task_iterators=None, messages=messages)
def _existing_destination_is_valid(self, destination_resource):
"""Returns True if a completed temporary component can be reused."""
digesters = upload_util.get_digesters(
self._source_resource, destination_resource)
with upload_util.get_stream(
self._transformed_source_resource,
length=self._length,
offset=self._offset,
digesters=digesters) as stream:
stream.seek(0, whence=os.SEEK_END) # Populates digesters.
try:
upload_util.validate_uploaded_object(
digesters, destination_resource, task_status_queue=None)
return True
except command_errors.HashMismatchError:
return False
def execute(self, task_status_queue=None):
"""Performs upload."""
digesters = upload_util.get_digesters(
self._source_resource, self._destination_resource)
destination_url = self._destination_resource.storage_url
provider = destination_url.scheme
if properties.VALUES.storage.enable_zonal_buckets_bidi_streaming.GetBool():
api = api_factory.get_api(
provider, bucket_name=destination_url.bucket_name
)
else:
api = api_factory.get_api(provider)
request_config = request_config_factory.get_request_config(
destination_url,
content_type=upload_util.get_content_type(
self._source_resource.storage_url.resource_name,
self._source_resource.storage_url.is_stream),
md5_hash=self._source_resource.md5_hash,
size=self._length,
user_request_args=self._user_request_args)
if self._component_number is None:
source_resource_for_metadata = self._source_resource
else:
source_resource_for_metadata = None
# This disables the Content-MD5 header for multi-part uploads.
request_config.resource_args.md5_hash = None
with upload_util.get_stream(
self._transformed_source_resource,
length=self._length,
offset=self._offset,
digesters=digesters,
task_status_queue=task_status_queue,
destination_resource=self._destination_resource,
component_number=self._component_number,
total_components=self._total_components) as source_stream:
upload_strategy = upload_util.get_upload_strategy(api, self._length)
if cloud_api.Capability.APPENDABLE_UPLOAD in api.capabilities:
destination_resource = api.upload_object(
source_stream,
self._destination_resource,
request_config,
posix_to_set=self._posix_to_set,
source_resource=source_resource_for_metadata,
upload_strategy=upload_strategy,
)
# DeferredCrc32c does not hash on-the-fly and needs a summation call.
if digesters.get(hash_util.HashAlgorithm.CRC32C, None) and isinstance(
digesters[hash_util.HashAlgorithm.CRC32C],
fast_crc32c_util.DeferredCrc32c,
):
digesters[hash_util.HashAlgorithm.CRC32C].sum_file(
self._source_path, self._offset, self._length
)
elif upload_strategy == cloud_api.UploadStrategy.RESUMABLE:
tracker_file_path = tracker_file_util.get_tracker_file_path(
self._destination_resource.storage_url,
tracker_file_util.TrackerFileType.UPLOAD,
component_number=self._component_number)
complete = False
encryption_key_hash_sha256 = getattr(
encryption_util.get_encryption_key(), 'sha256', None)
tracker_callback = functools.partial(
tracker_file_util.write_resumable_upload_tracker_file,
tracker_file_path, complete, encryption_key_hash_sha256)
tracker_data = tracker_file_util.read_resumable_upload_tracker_file(
tracker_file_path)
if (tracker_data is None or
tracker_data.encryption_key_sha256 != encryption_key_hash_sha256):
serialization_data = None
else:
# TODO(b/190093425): Print a better message for component uploads once
# the final destination resource is available in ComponentUploadTask.
log.status.Print(
'Resuming upload for ' + destination_url.resource_name
)
serialization_data = tracker_data.serialization_data
if tracker_data.complete:
try:
metadata_request_config = (
request_config_factory.get_request_config(
destination_url,
decryption_key_hash_sha256=encryption_key_hash_sha256))
# Providing a decryption key means the response will include the
# object's hash if the keys match, and raise an error if they do
# not. This is desirable since we want to re-upload objects with
# the wrong key, and need the object's hash for validation.
destination_resource = api.get_object_metadata(
destination_url.bucket_name, destination_url.resource_name,
metadata_request_config)
except api_errors.CloudApiError:
# Any problem fetching existing object metadata can be ignored,
# since we'll just re-upload the object.
pass
else:
# The API call will not error if we provide an encryption key but
# the destination is unencrypted, hence the additional (defensive)
# check below.
destination_key_hash = (
destination_resource.decryption_key_hash_sha256)
if (destination_key_hash == encryption_key_hash_sha256 and
self._existing_destination_is_valid(destination_resource)):
return self._get_output(destination_resource)
attempt_upload = functools.partial(
api.upload_object,
source_stream,
self._destination_resource,
request_config,
posix_to_set=self._posix_to_set,
serialization_data=serialization_data,
source_resource=source_resource_for_metadata,
tracker_callback=tracker_callback,
upload_strategy=upload_strategy,
)
def _handle_resumable_upload_error(exc_type, exc_value, exc_traceback,
state):
"""Returns true if resumable upload should retry on error argument."""
del exc_traceback # Unused.
if not (exc_type is api_errors.NotFoundError or
getattr(exc_value, 'status_code', None) == 410):
if exc_type is api_errors.ResumableUploadAbortError:
tracker_file_util.delete_tracker_file(tracker_file_path)
# Otherwise the error is probably a persistent network issue
# that is already retried by API clients, so we'll keep the tracker
# file to allow the user to retry the upload in a separate run.
return False
tracker_file_util.delete_tracker_file(tracker_file_path)
if state.retrial == 0:
# Ping bucket to see if it exists.
try:
api.get_bucket(self._destination_resource.storage_url.bucket_name)
except api_errors.CloudApiError as e:
# The user may not have permission to view the bucket metadata,
# so the ping may still be valid for access denied errors.
status = getattr(e, 'status_code', None)
if status not in (401, 403):
raise
return True
      # Convert seconds to milliseconds by multiplying by 1000.
destination_resource = retry.Retryer(
max_retrials=properties.VALUES.storage.max_retries.GetInt(),
wait_ceiling_ms=properties.VALUES.storage.max_retry_delay.GetInt() *
1000,
exponential_sleep_multiplier=(
properties.VALUES.storage.exponential_sleep_multiplier.GetInt()
)).RetryOnException(
attempt_upload,
sleep_ms=properties.VALUES.storage.base_retry_delay.GetInt() *
1000,
should_retry_if=_handle_resumable_upload_error)
tracker_data = tracker_file_util.read_resumable_upload_tracker_file(
tracker_file_path)
if tracker_data is not None:
if self._component_number is not None:
tracker_file_util.write_resumable_upload_tracker_file(
tracker_file_path,
complete=True,
encryption_key_sha256=tracker_data.encryption_key_sha256,
serialization_data=tracker_data.serialization_data)
else:
tracker_file_util.delete_tracker_file(tracker_file_path)
else:
destination_resource = api.upload_object(
source_stream,
self._destination_resource,
request_config,
posix_to_set=self._posix_to_set,
source_resource=source_resource_for_metadata,
upload_strategy=upload_strategy,
)
upload_util.validate_uploaded_object(digesters, destination_resource,
task_status_queue)
return self._get_output(destination_resource)
def __eq__(self, other):
if not isinstance(other, FilePartUploadTask):
return NotImplemented
return (
self._destination_resource == other._destination_resource
and self._source_resource == other._source_resource
and self._offset == other._offset
and self._length == other._length
and self._component_number == other._component_number
and self._total_components == other._total_components
and self._posix_to_set == other._posix_to_set
and self._user_request_args == other._user_request_args
)

View File

@@ -0,0 +1,371 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for file uploads.
Typically executed in a task iterator:
googlecloudsdk.command_lib.storage.tasks.task_executor.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import copy
import os
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.command_lib.storage import gzip_util
from googlecloudsdk.command_lib.storage import manifest_util
from googlecloudsdk.command_lib.storage import path_util
from googlecloudsdk.command_lib.storage import symlink_util
from googlecloudsdk.command_lib.storage import tracker_file_util
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.command_lib.storage.tasks import task_util
from googlecloudsdk.command_lib.storage.tasks.cp import copy_component_util
from googlecloudsdk.command_lib.storage.tasks.cp import copy_util
from googlecloudsdk.command_lib.storage.tasks.cp import file_part_upload_task
from googlecloudsdk.command_lib.storage.tasks.cp import finalize_composite_upload_task
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
class FileUploadTask(copy_util.ObjectCopyTaskWithExitHandler):
  """Represents a command operation triggering a file upload."""
  def __init__(
      self,
      source_resource,
      destination_resource,
      delete_source=False,
      is_composite_upload_eligible=False,
      posix_to_set=None,
      print_created_message=False,
      print_source_version=False,
      user_request_args=None,
      verbose=False,
  ):
    """Initializes task.
    Args:
      source_resource (resource_reference.FileObjectResource): Must contain
        local filesystem path to upload object. Does not need to contain
        metadata.
      destination_resource (resource_reference.ObjectResource|UnknownResource):
        Must contain the full object path. Directories will not be accepted.
        Existing objects at this location will be overwritten.
      delete_source (bool): If copy completes successfully, delete the source
        object afterwards.
      is_composite_upload_eligible (bool): If True, parallel composite upload
        may be performed.
      posix_to_set (PosixAttributes|None): See parent class.
      print_created_message (bool): See parent class.
      print_source_version (bool): See parent class.
      user_request_args (UserRequestArgs|None): See parent class.
      verbose (bool): See parent class.
    """
    super(FileUploadTask, self).__init__(
        source_resource,
        destination_resource,
        posix_to_set=posix_to_set,
        print_created_message=print_created_message,
        print_source_version=print_source_version,
        user_request_args=user_request_args,
        verbose=verbose,
    )
    self._delete_source = delete_source
    self._is_composite_upload_eligible = is_composite_upload_eligible
    # Keyed on the destination URL so tasks targeting the same object are
    # not run in parallel with each other.
    self.parallel_processing_key = (
        self._destination_resource.storage_url.url_string
    )
  def _perform_single_transfer(
      self,
      size,
      source_path,
      task_status_queue,
      temporary_paths_to_clean_up,
  ):
    """Uploads the file as a single object, then reports and cleans up.
    Args:
      size (int|None): Size of the upload source in bytes (None for streams).
      source_path (str): Local path to upload (possibly a temporary
        gzipped or symlink-placeholder file).
      task_status_queue: Queue for progress/manifest messages.
      temporary_paths_to_clean_up (list[str]): Temporary files to delete
        after the upload completes.
    """
    task_output = file_part_upload_task.FilePartUploadTask(
        self._source_resource,
        self._destination_resource,
        source_path,
        offset=0,
        length=size,
        posix_to_set=self._posix_to_set,
        user_request_args=self._user_request_args,
    ).execute(task_status_queue)
    result_resource = task_util.get_first_matching_message_payload(
        task_output.messages, task.Topic.CREATED_RESOURCE
    )
    if result_resource:
      self._print_created_message_if_requested(result_resource)
      if self._send_manifest_messages:
        manifest_util.send_success_message(
            task_status_queue,
            self._source_resource,
            self._destination_resource,
            md5_hash=result_resource.md5_hash,
        )
    # Remove temporary files (gzipped copies, symlink placeholders).
    for path in temporary_paths_to_clean_up:
      os.remove(path)
    if self._delete_source:
      # Delete original source file.
      os.remove(self._source_resource.storage_url.resource_name)
  def _get_user_request_args_for_composite_upload_chunks(self):
    """Returns the user args to be used for composite upload chunks."""
    if not self._user_request_args or not self._user_request_args.resource_args:
      return self._user_request_args
    # Deep copy so clearing per-chunk fields below does not mutate the args
    # used for the final composed object.
    user_args = copy.deepcopy(self._user_request_args)
    resource_args = user_args.resource_args
    # We do not want context to be uploaded for each chunk. Instead we will
    # set the context once the composite object is finalized.
    setattr(resource_args, 'custom_contexts_to_set', None)
    setattr(resource_args, 'custom_contexts_to_remove', None)
    setattr(resource_args, 'custom_contexts_to_update', None)
    # We also do not want metadata to be uploaded for each chunk.
    # See b/377305136 for more details.
    setattr(resource_args, 'custom_fields_to_set', None)
    setattr(resource_args, 'custom_fields_to_remove', None)
    setattr(resource_args, 'custom_fields_to_update', None)
    return user_args
  def _perform_composite_upload(
      self,
      api_client,
      component_count,
      size,
      source_path,
      task_status_queue,
      temporary_paths_to_clean_up,
  ):
    """Uploads the file in components and schedules a finalize (compose) task.
    Args:
      api_client: Provider API client (used for name-length and compose
        limits).
      component_count (int): Number of components to split the upload into.
      size (int): Size of the upload source in bytes.
      source_path (str): Local path to upload.
      task_status_queue: Queue for progress/manifest messages.
      temporary_paths_to_clean_up (list[str]): Temporary files for the
        finalize task to delete.
    Returns:
      task.Output scheduling component uploads followed by the finalize task,
      or the result of a single transfer if a component name would be too
      long.
    """
    tracker_file_path = tracker_file_util.get_tracker_file_path(
        self._destination_resource.storage_url,
        tracker_file_util.TrackerFileType.PARALLEL_UPLOAD,
        source_url=self._source_resource.storage_url,
    )
    tracker_data = tracker_file_util.read_composite_upload_tracker_file(
        tracker_file_path
    )
    if tracker_data:
      # Resuming a previous attempt: reuse its prefix so component names
      # match already-uploaded components.
      random_prefix = tracker_data.random_prefix
    else:
      random_prefix = path_util.generate_random_int_for_path()
    component_offsets_and_lengths = (
        copy_component_util.get_component_offsets_and_lengths(
            size, component_count
        )
    )
    temporary_component_resources = []
    for i in range(len(component_offsets_and_lengths)):
      temporary_component_resource = (
          copy_component_util.get_temporary_component_resource(
              self._source_resource,
              self._destination_resource,
              random_prefix,
              i,
          )
      )
      temporary_component_resources.append(temporary_component_resource)
      component_name_length = len(
          temporary_component_resource.storage_url.resource_name.encode()
      )
      # Fall back to a single-object upload if any generated component name
      # would exceed the API's object name length limit.
      if component_name_length > api_client.MAX_OBJECT_NAME_LENGTH:
        log.warning(
            'Performing a non-composite upload for {}, as a temporary'
            ' component resource would have a name of length {}. This is'
            ' longer than the maximum object name length supported by this'
            ' API: {} UTF-8 encoded bytes. You may be able to change the'
            ' storage/parallel_composite_upload_prefix config option to perform'
            ' a composite upload with this object.'.format(
                self._source_resource.storage_url,
                component_name_length,
                api_client.MAX_OBJECT_NAME_LENGTH,
            )
        )
        return self._perform_single_transfer(
            size,
            source_path,
            task_status_queue,
            temporary_paths_to_clean_up,
        )
    file_part_upload_tasks = []
    for i, (offset, length) in enumerate(component_offsets_and_lengths):
      upload_task = file_part_upload_task.FilePartUploadTask(
          self._source_resource,
          temporary_component_resources[i],
          source_path,
          offset,
          length,
          component_number=i,
          total_components=len(component_offsets_and_lengths),
          user_request_args=self._get_user_request_args_for_composite_upload_chunks(),
      )
      file_part_upload_tasks.append(upload_task)
    finalize_upload_task = (
        finalize_composite_upload_task.FinalizeCompositeUploadTask(
            expected_component_count=len(file_part_upload_tasks),
            source_resource=self._source_resource,
            destination_resource=self._destination_resource,
            delete_source=self._delete_source,
            posix_to_set=self._posix_to_set,
            print_created_message=self._print_created_message,
            random_prefix=random_prefix,
            temporary_paths_to_clean_up=temporary_paths_to_clean_up,
            user_request_args=self._user_request_args,
        )
    )
    tracker_file_util.write_composite_upload_tracker_file(
        tracker_file_path, random_prefix
    )
    # Component uploads run first (possibly in parallel); the finalize task
    # then composes them into the destination object.
    return task.Output(
        additional_task_iterators=[
            file_part_upload_tasks,
            [finalize_upload_task],
        ],
        messages=None,
    )
  def _handle_symlink_placeholder_transform(
      self, source_path, temporary_paths_to_clean_up
  ):
    """Create a symlink placeholder if necessary.
    Args:
      source_path (str): The source of the upload.
      temporary_paths_to_clean_up (list[str]): Adds the paths of any temporary
        files created to this list.
    Returns:
      The path to the symlink placeholder if one was created. Otherwise, returns
      source_path.
    """
    should_create_symlink_placeholder = (
        symlink_util.get_preserve_symlink_from_user_request(
            self._user_request_args
        )
        and self._source_resource.is_symlink
    )
    if should_create_symlink_placeholder:
      symlink_path = symlink_util.get_symlink_placeholder_file(
          self._source_resource.storage_url.resource_name
      )
      temporary_paths_to_clean_up.append(symlink_path)
      return symlink_path
    else:
      return source_path
  def _handle_gzip_transform(self, source_path, temporary_paths_to_clean_up):
    """Gzip the file at source_path if necessary.
    Args:
      source_path (str): The source of the upload.
      temporary_paths_to_clean_up (list[str]): Adds the paths of any temporary
        files created to this list.
    Returns:
      The path to the gzipped temporary file if one was created. Otherwise,
      returns source_path.
    """
    should_gzip_locally = gzip_util.should_gzip_locally(
        getattr(self._user_request_args, 'gzip_settings', None), source_path
    )
    if should_gzip_locally:
      gzip_path = gzip_util.get_temporary_gzipped_file(source_path)
      temporary_paths_to_clean_up.append(gzip_path)
      return gzip_path
    else:
      return source_path
  def execute(self, task_status_queue=None):
    """Runs the upload, choosing a single or composite strategy."""
    destination_provider = self._destination_resource.storage_url.scheme
    api_client = api_factory.get_api(destination_provider)
    # Honor no-clobber settings: skip the upload entirely if the destination
    # already exists.
    if copy_util.check_for_cloud_clobber(
        self._user_request_args, api_client, self._destination_resource
    ):
      log.status.Print(
          copy_util.get_no_clobber_message(
              self._destination_resource.storage_url
          )
      )
      if self._send_manifest_messages:
        manifest_util.send_skip_message(
            task_status_queue,
            self._source_resource,
            self._destination_resource,
            copy_util.get_no_clobber_message(
                self._destination_resource.storage_url
            ),
        )
      return
    source_url = self._source_resource.storage_url
    temporary_paths_to_clean_up = []
    if source_url.is_stream:
      # Streams are uploaded directly: no on-disk transforms are applied and
      # the size is unknown (None).
      source_path = source_url.resource_name
      size = None
    else:
      # Apply the symlink-placeholder transform first, then gzip the
      # (possibly transformed) path.
      symlink_transformed_path = self._handle_symlink_placeholder_transform(
          source_url.resource_name,
          temporary_paths_to_clean_up
      )
      source_path = self._handle_gzip_transform(
          symlink_transformed_path,
          temporary_paths_to_clean_up
      )
      # Size of the file actually uploaded (after any transforms).
      size = os.path.getsize(source_path)
    component_count = copy_component_util.get_component_count(
        size,
        properties.VALUES.storage.parallel_composite_upload_component_size.Get(),
        api_client.MAX_OBJECTS_PER_COMPOSE_CALL,
    )
    should_perform_single_transfer = (
        not self._is_composite_upload_eligible
        or not task_util.should_use_parallelism()
        or component_count <= 1
    )
    if should_perform_single_transfer:
      self._perform_single_transfer(
          size, source_path, task_status_queue, temporary_paths_to_clean_up
      )
    else:
      return self._perform_composite_upload(
          api_client,
          component_count,
          size,
          source_path,
          task_status_queue,
          temporary_paths_to_clean_up,
      )

View File

@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains logic for finalizing composite uploads."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import os
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import manifest_util
from googlecloudsdk.command_lib.storage import tracker_file_util
from googlecloudsdk.command_lib.storage.tasks import compose_objects_task
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.command_lib.storage.tasks import task_util
from googlecloudsdk.command_lib.storage.tasks.cp import copy_util
from googlecloudsdk.command_lib.storage.tasks.cp import delete_temporary_components_task
class FinalizeCompositeUploadTask(copy_util.ObjectCopyTaskWithExitHandler):
  """Composes and deletes object resources received as messages."""
  def __init__(
      self,
      expected_component_count,
      source_resource,
      destination_resource,
      delete_source=False,
      posix_to_set=None,
      print_created_message=False,
      random_prefix='',
      temporary_paths_to_clean_up=None,
      user_request_args=None,
  ):
    """Initializes task.
    Args:
      expected_component_count (int): Number of temporary components expected.
      source_resource (resource_reference.FileObjectResource): The local
        uploaded file.
      destination_resource (resource_reference.UnknownResource): Metadata for
        the final composite object.
      delete_source (bool): If copy completes successfully, delete the source
        object afterwards.
      posix_to_set (PosixAttributes|None): See parent class.
      print_created_message (bool): See parent class.
      random_prefix (str): Random id added to component names.
      temporary_paths_to_clean_up (list[str]|None): Paths to remove after the
        composite upload completes. This may include a temporary gzipped
        version of the source, or symlink placeholders.
      user_request_args (UserRequestArgs|None): See parent class.
    """
    super(FinalizeCompositeUploadTask, self).__init__(
        source_resource,
        destination_resource,
        posix_to_set=posix_to_set,
        print_created_message=print_created_message,
        user_request_args=user_request_args,
    )
    self._expected_component_count = expected_component_count
    self._delete_source = delete_source
    self._random_prefix = random_prefix
    self._temporary_paths_to_clean_up = temporary_paths_to_clean_up
  def execute(self, task_status_queue=None):
    """Composes uploaded components into the final object and cleans up.
    Args:
      task_status_queue: Queue for progress and manifest messages.
    Returns:
      task.Output scheduling deletion of the temporary cloud components.
    Raises:
      errors.Error: If the number of uploaded components does not match the
        expected count.
    """
    # Collect component resources reported by the upload tasks that ran
    # before this task.
    uploaded_components = [
        message.payload
        for message in self.received_messages
        if message.topic == task.Topic.UPLOADED_COMPONENT
    ]
    if len(uploaded_components) != self._expected_component_count:
      raise errors.Error(
          'Temporary components were not uploaded correctly.'
          ' Please retry this upload.'
      )
    # Components must be composed in order, so sort by component number.
    uploaded_objects = [
        component.object_resource
        for component in sorted(
            uploaded_components,
            key=lambda component: component.component_number)
    ]
    compose_task = compose_objects_task.ComposeObjectsTask(
        uploaded_objects,
        self._destination_resource,
        original_source_resource=self._source_resource,
        posix_to_set=self._posix_to_set,
        user_request_args=self._user_request_args,
    )
    compose_task_output = compose_task.execute(
        task_status_queue=task_status_queue
    )
    result_resource = task_util.get_first_matching_message_payload(
        compose_task_output.messages, task.Topic.CREATED_RESOURCE
    )
    if result_resource:
      self._print_created_message_if_requested(result_resource)
      if self._send_manifest_messages:
        manifest_util.send_success_message(
            task_status_queue,
            self._source_resource,
            self._destination_resource,
            md5_hash=result_resource.md5_hash,
        )
    # After a successful compose call, we consider the upload complete and can
    # delete tracker files.
    # Pass the source StorageUrl (not the resource object) so the computed
    # tracker path matches the one used when the tracker file was written in
    # FileUploadTask._perform_composite_upload.
    tracker_file_path = tracker_file_util.get_tracker_file_path(
        self._destination_resource.storage_url,
        tracker_file_util.TrackerFileType.PARALLEL_UPLOAD,
        source_url=self._source_resource.storage_url,
    )
    tracker_file_util.delete_tracker_file(tracker_file_path)
    # Remove temporary local files (gzipped copies, symlink placeholders).
    for path in self._temporary_paths_to_clean_up or []:
      os.remove(path)
    if self._delete_source:
      # Delete original source file.
      os.remove(self._source_resource.storage_url.resource_name)
    # Schedule deletion of the now-composed temporary cloud components.
    return task.Output(
        additional_task_iterators=[
            [
                delete_temporary_components_task.DeleteTemporaryComponentsTask(
                    self._source_resource,
                    self._destination_resource,
                    self._random_prefix,
                )
            ]
        ],
        messages=None,
    )
  def __eq__(self, other):
    """Checks equality against another task instance."""
    if not isinstance(other, type(self)):
      return NotImplemented
    # NOTE(review): _delete_source, _posix_to_set, and _print_created_message
    # are not compared here — confirm this omission is intentional.
    return (
        self._expected_component_count == other._expected_component_count
        and self._source_resource == other._source_resource
        and self._destination_resource == other._destination_resource
        and self._random_prefix == other._random_prefix
        and self._temporary_paths_to_clean_up
        == other._temporary_paths_to_clean_up
        and self._user_request_args == other._user_request_args
    )

View File

@@ -0,0 +1,172 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for performing final steps of sliced download.
Typically executed in a task iterator:
googlecloudsdk.command_lib.storage.tasks.task_executor.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import manifest_util
from googlecloudsdk.command_lib.storage import posix_util
from googlecloudsdk.command_lib.storage import symlink_util
from googlecloudsdk.command_lib.storage import tracker_file_util
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.command_lib.storage.tasks.cp import copy_util
from googlecloudsdk.command_lib.storage.tasks.cp import download_util
from googlecloudsdk.command_lib.storage.tasks.rm import delete_task
from googlecloudsdk.command_lib.util import crc32c
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
class FinalizeSlicedDownloadTask(copy_util.ObjectCopyTaskWithExitHandler):
  """Performs final steps of sliced download."""
  def __init__(
      self,
      source_resource,
      temporary_destination_resource,
      final_destination_resource,
      delete_source=False,
      do_not_decompress=False,
      posix_to_set=None,
      print_created_message=False,
      system_posix_data=None,
      user_request_args=None,
  ):
    """Initializes task.
    Args:
      source_resource (resource_reference.ObjectResource): Should contain
        object's metadata for checking content encoding.
      temporary_destination_resource (resource_reference.FileObjectResource):
        Must contain a local path to the temporary file written to during
        transfers.
      final_destination_resource (resource_reference.FileObjectResource): Must
        contain local filesystem path to the final download destination.
      delete_source (bool): If copy completes successfully, delete the source
        object afterwards.
      do_not_decompress (bool): Prevents automatically decompressing downloaded
        gzips.
      posix_to_set (PosixAttributes|None): See parent class.
      print_created_message (bool): See parent class.
      system_posix_data (SystemPosixData): System-wide POSIX info.
      user_request_args (UserRequestArgs|None): See parent class.
    """
    super(FinalizeSlicedDownloadTask, self).__init__(
        source_resource,
        final_destination_resource,
        posix_to_set=posix_to_set,
        print_created_message=print_created_message,
        user_request_args=user_request_args,
    )
    self._temporary_destination_resource = temporary_destination_resource
    self._final_destination_resource = final_destination_resource
    self._delete_source = delete_source
    self._do_not_decompress = do_not_decompress
    self._system_posix_data = system_posix_data
  def execute(self, task_status_queue=None):
    """Validates and cleans up after sliced download."""
    # A single failed component invalidates the whole download, so surface
    # all component errors before finalizing.
    component_error_occurred = False
    for message in self.received_messages:
      if message.topic is task.Topic.ERROR:
        log.error(message.payload)
        component_error_occurred = True
    if component_error_occurred:
      raise errors.Error(
          'Failed to download one or more component of sliced download.')
    temporary_object_path = (
        self._temporary_destination_resource.storage_url.resource_name)
    final_destination_object_path = (
        self._final_destination_resource.storage_url.resource_name)
    # Only verify hashes when enabled and the source reports a CRC32C hash.
    if (properties.VALUES.storage.check_hashes.Get() !=
        properties.CheckHashes.NEVER.value and
        self._source_resource.crc32c_hash):
      component_payloads = [
          message.payload
          for message in self.received_messages
          if message.topic == task.Topic.CRC32C
      ]
      if component_payloads:
        # Returns list of payload values sorted by component number.
        sorted_component_payloads = sorted(
            component_payloads, key=lambda d: d['component_number'])
        # Concatenate per-component CRC32C checksums, in component order,
        # into one checksum for the whole file without re-reading it.
        downloaded_file_checksum = sorted_component_payloads[0][
            'crc32c_checksum']
        for i in range(1, len(sorted_component_payloads)):
          payload = sorted_component_payloads[i]
          downloaded_file_checksum = crc32c.concat_checksums(
              downloaded_file_checksum,
              payload['crc32c_checksum'],
              b_byte_count=payload['length'])
        downloaded_file_hash_object = crc32c.get_crc32c_from_checksum(
            downloaded_file_checksum)
        downloaded_file_hash_digest = crc32c.get_hash(
            downloaded_file_hash_object)
        # Removes corrupt temporary files if the digest does not match.
        download_util.validate_download_hash_and_delete_corrupt_files(
            temporary_object_path, self._source_resource.crc32c_hash,
            downloaded_file_hash_digest)
    preserve_symlinks = symlink_util.get_preserve_symlink_from_user_request(
        self._user_request_args
    )
    # Move the validated temporary file into its final destination.
    download_util.finalize_download(
        self._source_resource,
        temporary_object_path,
        final_destination_object_path,
        convert_symlinks=preserve_symlinks,
        do_not_decompress_flag=self._do_not_decompress,
    )
    tracker_file_util.delete_download_tracker_files(
        self._temporary_destination_resource.storage_url)
    posix_util.run_if_setting_posix(
        self._posix_to_set,
        self._user_request_args,
        posix_util.set_posix_attributes_on_file_if_valid,
        self._system_posix_data,
        self._source_resource,
        self._final_destination_resource,
        known_source_posix=self._posix_to_set,
        preserve_symlinks=preserve_symlinks,
    )
    self._print_created_message_if_requested(self._final_destination_resource)
    if self._send_manifest_messages:
      # Does not send md5_hash because sliced download uses CRC32C.
      manifest_util.send_success_message(
          task_status_queue,
          self._source_resource,
          self._final_destination_resource,
      )
    if self._delete_source:
      # Delete the cloud source via a follow-up task after a successful copy.
      return task.Output(
          additional_task_iterators=[[
              delete_task.DeleteObjectTask(self._source_resource.storage_url),
          ]],
          messages=None,
      )

View File

@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for copying an object around the cloud.
Typically executed in a task iterator:
googlecloudsdk.command_lib.storage.tasks.task_executor.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import os
import threading
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import manifest_util
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.command_lib.storage.tasks import task_status
from googlecloudsdk.command_lib.storage.tasks.cp import copy_util
from googlecloudsdk.command_lib.storage.tasks.rm import delete_task
from googlecloudsdk.core import log
class IntraCloudCopyTask(copy_util.ObjectCopyTaskWithExitHandler):
  """Represents a command operation copying an object around the cloud."""
  def __init__(
      self,
      source_resource,
      destination_resource,
      delete_source=False,
      fetch_source_fields_scope=None,
      posix_to_set=None,
      print_created_message=False,
      print_source_version=False,
      user_request_args=None,
      verbose=False,
  ):
    """Initializes task.
    Args:
      source_resource (resource_reference.Resource): Must contain the full
        object path. Directories will not be accepted.
      destination_resource (resource_reference.Resource): Must contain the full
        object path. Directories will not be accepted. Existing objects at
        this location will be overwritten.
      delete_source (bool): If copy completes successfully, delete the source
        object afterwards.
      fetch_source_fields_scope (FieldsScope|None): If present, refetch
        source_resource, populated with metadata determined by this FieldsScope.
        Useful for lazy or parallelized GET calls.
      posix_to_set (PosixAttributes|None): See parent class.
      print_created_message (bool): See parent class.
      print_source_version (bool): See parent class.
      user_request_args (UserRequestArgs|None): See parent class
      verbose (bool): See parent class.
    Raises:
      errors.InvalidUrlError: If the URLs are from different providers or the
        source is not a cloud URL.
    """
    super(IntraCloudCopyTask, self).__init__(
        source_resource,
        destination_resource,
        posix_to_set=posix_to_set,
        print_created_message=print_created_message,
        print_source_version=print_source_version,
        user_request_args=user_request_args,
        verbose=verbose,
    )
    # This task only supports same-provider, cloud-to-cloud copies.
    if ((source_resource.storage_url.scheme
         != destination_resource.storage_url.scheme)
        or not isinstance(source_resource.storage_url,
                          storage_url.CloudUrl)):
      raise errors.InvalidUrlError(
          'IntraCloudCopyTask takes two URLs from the same cloud provider.'
      )
    self._delete_source = delete_source
    self._fetch_source_fields_scope = fetch_source_fields_scope
    # Keyed on the destination URL so tasks targeting the same object are
    # not run in parallel with each other.
    self.parallel_processing_key = (
        self._destination_resource.storage_url.url_string)
  def execute(self, task_status_queue=None):
    """Performs the copy, reporting progress and manifest messages."""
    api_client = api_factory.get_api(self._source_resource.storage_url.scheme)
    # Honor no-clobber settings: skip the copy if the destination exists.
    if copy_util.check_for_cloud_clobber(self._user_request_args, api_client,
                                         self._destination_resource):
      log.status.Print(
          copy_util.get_no_clobber_message(
              self._destination_resource.storage_url))
      if self._send_manifest_messages:
        manifest_util.send_skip_message(
            task_status_queue, self._source_resource,
            self._destination_resource,
            copy_util.get_no_clobber_message(
                self._destination_resource.storage_url))
      return
    progress_callback = progress_callbacks.FilesAndBytesProgressCallback(
        status_queue=task_status_queue,
        offset=0,
        length=self._source_resource.size,
        source_url=self._source_resource.storage_url,
        destination_url=self._destination_resource.storage_url,
        operation_name=task_status.OperationName.INTRA_CLOUD_COPYING,
        process_id=os.getpid(),
        thread_id=threading.get_ident(),
    )
    if self._fetch_source_fields_scope:
      # Refetch source metadata at the requested fields scope (supports the
      # lazy/parallelized GET use case described in __init__).
      copy_source = api_client.get_object_metadata(
          self._source_resource.bucket,
          self._source_resource.name,
          generation=self._source_resource.generation,
          fields_scope=self._fetch_source_fields_scope,
      )
    else:
      copy_source = self._source_resource
    # Pass through the source's decryption key hash so encrypted source
    # objects can be read during the copy.
    request_config = request_config_factory.get_request_config(
        self._destination_resource.storage_url,
        decryption_key_hash_sha256=(
            self._source_resource.decryption_key_hash_sha256),
        user_request_args=self._user_request_args)
    result_resource = api_client.copy_object(
        copy_source,
        self._destination_resource,
        request_config,
        posix_to_set=self._posix_to_set,
        progress_callback=progress_callback,
    )
    self._print_created_message_if_requested(result_resource)
    if self._send_manifest_messages:
      manifest_util.send_success_message(
          task_status_queue,
          self._source_resource,
          self._destination_resource,
          md5_hash=result_resource.md5_hash)
    if self._delete_source:
      # Delete the source via a follow-up task after a successful copy.
      return task.Output(
          additional_task_iterators=[
              [delete_task.DeleteObjectTask(self._source_resource.storage_url)]
          ],
          messages=None,
      )
  def __eq__(self, other):
    """Checks equality against another task instance."""
    if not isinstance(other, IntraCloudCopyTask):
      return NotImplemented
    return (
        self._source_resource == other._source_resource
        and self._destination_resource == other._destination_resource
        and self._delete_source == other._delete_source
        and self._fetch_source_fields_scope == other._fetch_source_fields_scope
        and self._posix_to_set == other._posix_to_set
        and self._print_created_message == other._print_created_message
        and self._print_source_version == other._print_source_version
        and self._user_request_args == other._user_request_args
        and self._verbose == other._verbose
    )

View File

@@ -0,0 +1,204 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for parallel composite upload operation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import textwrap
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import cloud_api
from googlecloudsdk.api_lib.storage import errors
from googlecloudsdk.command_lib.storage.resources import resource_reference
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import scaled_integer
_STANDARD_STORAGE_CLASS = 'STANDARD'
def is_destination_composite_upload_compatible(destination_resource,
                                               user_request_args):
  """Checks if destination bucket is compatible for parallel composite upload.

  This function performs a GET bucket call to determine if the bucket's default
  storage class and retention period meet the criteria.

  Args:
    destination_resource (CloudResource|UnknownResource): Destination resource
      to which the files should be uploaded.
    user_request_args (UserRequestArgs|None): Values from user flags.

  Returns:
    True if the bucket satisfies the storage class and retention policy
    criteria.
  """
  api_client = api_factory.get_api(destination_resource.storage_url.scheme)
  try:
    bucket_resource = api_client.get_bucket(
        destination_resource.storage_url.bucket_name)
  except errors.CloudApiError as e:
    status = getattr(e, 'status_code', None)
    if status in (401, 403):
      # Without bucket read access the compatibility check cannot run, so fall
      # back to a regular (non-composite) upload instead of failing the copy.
      # Fixed grammar in this user-facing message ("does not permission" ->
      # "does not have permission").
      log.error(
          'Cannot check if the destination bucket is compatible for running'
          ' parallel composite uploads as the user does not have permission'
          ' to perform GET operation on the bucket. The operation will be'
          ' performed without parallel composite upload feature and hence'
          ' might perform relatively slower.')
      return False
    else:
      raise

  resource_args = getattr(user_request_args, 'resource_args', None)
  object_storage_class = getattr(resource_args, 'storage_class', None)

  # Composite objects are incompatible with holds, retention periods, and
  # non-STANDARD storage classes. Record the first blocking condition, if any,
  # so it can be surfaced to the user in the warning below.
  if bucket_resource.retention_period is not None:
    reason = 'Destination bucket has retention period set'
  elif bucket_resource.default_event_based_hold:
    reason = 'Destination bucket has event-based hold set'
  elif getattr(resource_args, 'event_based_hold', None):
    reason = 'Object will be created with event-based hold'
  elif getattr(resource_args, 'temporary_hold', None):
    reason = 'Object will be created with temporary hold'
  elif (bucket_resource.default_storage_class != _STANDARD_STORAGE_CLASS and
        object_storage_class != _STANDARD_STORAGE_CLASS):
    reason = 'Destination has a default storage class other than "STANDARD"'
  elif object_storage_class not in (None, _STANDARD_STORAGE_CLASS):
    reason = 'Object will be created with a storage class other than "STANDARD"'
  else:
    return True

  log.warning(
      '{}, hence parallel'
      ' composite upload will not be performed. If you would like to disable'
      ' this check, run: gcloud config set '
      'storage/parallel_composite_upload_compatibility_check=False'.format(
          reason))
  return False
def is_composite_upload_eligible(source_resource,
                                 destination_resource,
                                 user_request_args=None):
  """Checks if parallel composite upload should be performed.

  Logs tailored warning based on user configuration and the context
  of the operation. Informs user about configuration options they may want
  to set. In order to avoid repeated warning raised for each task, this
  function updates the storage/parallel_composite_upload_enabled property
  so that the warnings are logged only once.

  Args:
    source_resource (FileObjectResource): The source file resource to be
      uploaded.
    destination_resource (CloudResource|UnknownResource): Destination resource
      to which the files should be uploaded.
    user_request_args (UserRequestArgs|None): Values for RequestConfig.

  Returns:
    True if the parallel composite upload can be performed. However, this does
    not guarantee that parallel composite upload will be performed as the
    parallelism check can happen only after the task executor starts running
    because it sets the process_count and thread_count. We also let the task
    determine the component count.
  """
  # NOTE: the enabled property is tri-state: True (opted in), False (opted
  # out), and None (unset; treated as on-by-default with a one-time warning).
  composite_upload_enabled = (
      properties.VALUES.storage.parallel_composite_upload_enabled.GetBool())
  if composite_upload_enabled is False:  # pylint: disable=g-bool-id-comparison
    # Can't do "if not composite_upload_enabled" here because
    # None has a different behavior.
    return False

  if not isinstance(source_resource, resource_reference.FileObjectResource):
    # Source resource can be of type UnknownResource, hence check the type.
    return False

  try:
    # Small files do not benefit from composite upload; require the size to be
    # known and at least the configured threshold.
    if (source_resource.size is None or
        source_resource.size < scaled_integer.ParseInteger(
            properties.VALUES.storage.parallel_composite_upload_threshold.Get()
        )):
      return False
  except OSError as e:
    log.warning('Size cannot be determined for resource: %s. Error: %s',
                source_resource, e)
    return False

  compatibility_check_required = (
      properties.VALUES.storage.parallel_composite_upload_compatibility_check
      .GetBool())
  if composite_upload_enabled and not compatibility_check_required:
    # User explicitly opted in and opted out of the bucket check: no GET
    # bucket call and no warning needed.
    return True

  api_capabilities = api_factory.get_capabilities(
      destination_resource.storage_url.scheme,
      bucket_name=(
          destination_resource.storage_url.bucket_name
          if properties.VALUES.storage.enable_zonal_buckets_bidi_streaming.GetBool()
          else None
      ),
  )
  if cloud_api.Capability.COMPOSE_OBJECTS not in api_capabilities:
    # We can silently disable parallel composite upload because the destination
    # capability will not change during the execution.
    # TODO(b/245738490) Explore if setting this property can be avoided.
    properties.VALUES.storage.parallel_composite_upload_enabled.Set(False)
    return False

  if compatibility_check_required:
    can_perform_composite_upload = (
        is_destination_composite_upload_compatible(destination_resource,
                                                   user_request_args))
    # Indicates that we don't have to repeat compatibility check.
    properties.VALUES.storage.parallel_composite_upload_compatibility_check.Set(
        False)
  else:
    can_perform_composite_upload = True

  if can_perform_composite_upload and composite_upload_enabled is None:
    # Property was unset: the feature is on by default, so tell the user once
    # how to opt out or how to silence this warning.
    log.warning(
        '\n'.join(
            textwrap.fill(l, width=70)
            for l in (
                'Parallel composite upload was turned ON to get the best'
                ' performance on uploading large objects. If you would like to'
                ' opt-out and instead perform a normal upload, run:'
                '\n`gcloud config set storage/parallel_composite_upload_enabled'
                ' False`'
                '\nIf you would like to disable this warning, run:'
                '\n`gcloud config set storage/parallel_composite_upload_enabled'
                ' True`'
                # We say "might" here because whether parallel composite upload
                # is used or not also depends on whether parallelism is True.
                '\nNote that with parallel composite uploads, your object might'
                ' be uploaded as a composite object'
                ' (https://cloud.google.com/storage/docs/composite-objects),'
                ' which means that any user who downloads your object will'
                ' need to use crc32c checksums to verify data integrity. gcloud'
                ' storage is capable of computing crc32c checksums, but this'
                ' might pose a problem for other clients.'
            ).splitlines()
        )
        + '\n'
    )
  # TODO(b/245738490) Explore if setting this property can be avoided.
  properties.VALUES.storage.parallel_composite_upload_enabled.Set(
      can_perform_composite_upload)
  return can_perform_composite_upload

View File

@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for streaming downloads.
Typically executed in a task iterator:
googlecloudsdk.command_lib.storage.tasks.task_executor.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import os
import sys
import threading
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import cloud_api
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.api_lib.storage.gcs_grpc_bidi_streaming import client as grpc_bidi_streaming_client
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task_status
from googlecloudsdk.command_lib.storage.tasks.cp import copy_util
from googlecloudsdk.core import exceptions as core_exceptions
from googlecloudsdk.core import properties
class StreamingDownloadTask(copy_util.ObjectCopyTask):
  """Represents a command operation triggering a streaming download."""

  def __init__(
      self,
      source_resource,
      destination_resource,
      download_stream,
      print_created_message=False,
      print_source_version=False,
      show_url=False,
      start_byte=0,
      end_byte=None,
      user_request_args=None,
      verbose=False,
  ):
    """Initializes task.

    Args:
      source_resource (ObjectResource): Must contain the full path of object to
        download, including bucket. Directories will not be accepted. Does not
        need to contain metadata.
      destination_resource (resource_reference.Resource): Target resource to
        copy to. In this case, it contains the path of the destination stream or
        '-' for stdout.
      download_stream (stream): Reusable stream to write download to.
      print_created_message (bool): See parent class.
      print_source_version (bool): See parent class.
      show_url (bool): Says whether or not to print the header before each
        object's content.
      start_byte (int): The byte index to start streaming from.
      end_byte (int|None): The byte index to stop streaming from.
      user_request_args (UserRequestArgs|None): See parent class.
      verbose (bool): See parent class.
    """
    super(StreamingDownloadTask, self).__init__(
        source_resource,
        destination_resource,
        print_created_message=print_created_message,
        print_source_version=print_source_version,
        user_request_args=user_request_args,
        verbose=verbose,
    )
    self._download_stream = download_stream
    self._show_url = show_url
    self._start_byte = start_byte
    self._end_byte = end_byte

  def execute(self, task_status_queue=None):
    """Streams the object to the destination stream.

    Args:
      task_status_queue (multiprocessing.Queue|None): If provided, progress
        messages are sent to it; otherwise no progress is reported.
    """
    if self._show_url:
      # Header goes to stderr so it does not pollute piped object content.
      sys.stderr.write('==> {} <==\n'.format(self._source_resource))
    if task_status_queue:
      progress_callback = progress_callbacks.FilesAndBytesProgressCallback(
          status_queue=task_status_queue,
          offset=0,
          length=self._source_resource.size,
          source_url=self._source_resource.storage_url,
          destination_url=self._download_stream.name,
          operation_name=task_status.OperationName.DOWNLOADING,
          process_id=os.getpid(),
          thread_id=threading.get_ident(),
      )
    else:
      progress_callback = None

    if (self._source_resource.size and
        self._start_byte >= self._source_resource.size):
      # Nothing to download; still mark the full size as complete so progress
      # reporting finishes cleanly.
      if progress_callback:
        progress_callback(self._source_resource.size)
      return

    request_config = request_config_factory.get_request_config(
        self._source_resource.storage_url,
        decryption_key_hash_sha256=(
            self._source_resource.decryption_key_hash_sha256),
        user_request_args=self._user_request_args,
    )

    args = [self._source_resource.storage_url.scheme]
    if properties.VALUES.storage.enable_zonal_buckets_bidi_streaming.GetBool():
      # With bidi streaming enabled, the API factory needs the bucket name to
      # pick a bucket-specific client.
      args.append(self._source_resource.storage_url.bucket_name)
    api = api_factory.get_api(*args)
    if isinstance(api, grpc_bidi_streaming_client.GcsGrpcBidiStreamingClient):
      raise core_exceptions.InternalError(
          'Only Simple/Sliced downloads are supported for zonal buckets via'
          ' Grpc Bidi Streaming API.'
      )
    api.download_object(
        self._source_resource,
        self._download_stream,
        request_config,
        download_strategy=cloud_api.DownloadStrategy.ONE_SHOT,
        progress_callback=progress_callback,
        start_byte=self._start_byte,
        end_byte=self._end_byte)
    # Flush so content is visible immediately when streaming to a pipe/stdout.
    self._download_stream.flush()
    self._print_created_message_if_requested(self._destination_resource)

  def __eq__(self, other):
    """Compares tasks on the attributes that define a streaming download."""
    if not isinstance(other, self.__class__):
      return NotImplemented
    return (
        self._source_resource == other._source_resource
        and self._destination_resource == other._destination_resource
        and self._download_stream == other._download_stream
        and self._print_created_message == other._print_created_message
        and self._user_request_args == other._user_request_args
        and self._show_url == other._show_url
        and self._start_byte == other._start_byte
        and self._end_byte == other._end_byte
    )

View File

@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for streaming uploads."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import cloud_api
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import user_request_args_factory
from googlecloudsdk.command_lib.storage.tasks.cp import copy_util
from googlecloudsdk.command_lib.storage.tasks.cp import upload_util
from googlecloudsdk.core import properties
class StreamingUploadTask(copy_util.ObjectCopyTask):
  """Represents a command operation triggering a streaming upload."""

  def __init__(
      self,
      source_resource,
      destination_resource,
      posix_to_set=None,
      print_created_message=False,
      print_source_version=False,
      user_request_args=None,
      verbose=False,
  ):
    """Initializes task.

    Args:
      source_resource (FileObjectResource): Points to the stream or named pipe
        to read from.
      destination_resource (UnknownResource|ObjectResource): The full path of
        object to upload to.
      posix_to_set (PosixAttributes|None): See parent class.
      print_created_message (bool): See parent class.
      print_source_version (bool): See parent class.
      user_request_args (UserRequestArgs|None): See parent class.
      verbose (bool): See parent class.
    """
    # The parent class stores source and destination resources (sibling tasks
    # such as StreamingDownloadTask rely on the inherited attributes), so the
    # redundant re-assignments that used to follow this call were removed.
    super(StreamingUploadTask, self).__init__(
        source_resource,
        destination_resource,
        posix_to_set=posix_to_set,
        print_created_message=print_created_message,
        print_source_version=print_source_version,
        user_request_args=user_request_args,
        verbose=verbose,
    )

  def execute(self, task_status_queue=None):
    """Runs upload from stream.

    Args:
      task_status_queue (multiprocessing.Queue|None): If provided, progress
        messages are sent to it; otherwise no progress is reported.

    Raises:
      errors.Error: If local gzip compression was requested, which streaming
        uploads do not support.
    """
    request_config = request_config_factory.get_request_config(
        self._destination_resource.storage_url,
        content_type=upload_util.get_content_type(
            self._source_resource.storage_url.resource_name, is_stream=True),
        md5_hash=self._source_resource.md5_hash,
        user_request_args=self._user_request_args)
    if getattr(request_config, 'gzip_settings', None):
      gzip_type = getattr(request_config.gzip_settings, 'type', None)
      if gzip_type is user_request_args_factory.GzipType.LOCAL:
        # TODO(b/202729249): Can support this after dropping Python 2.
        raise errors.Error(
            'Gzip content encoding is not currently supported for streaming'
            ' uploads. Remove the compression flag or save the streamed output'
            ' to a file before uploading.')

    digesters = upload_util.get_digesters(
        self._source_resource,
        self._destination_resource)

    stream = upload_util.get_stream(
        self._source_resource,
        digesters=digesters,
        task_status_queue=task_status_queue,
        destination_resource=self._destination_resource)

    with stream:
      provider = self._destination_resource.storage_url.scheme
      if properties.VALUES.storage.enable_zonal_buckets_bidi_streaming.GetBool():
        # With bidi streaming enabled, the API factory needs the bucket name
        # to pick a bucket-specific client.
        api = api_factory.get_api(
            provider,
            bucket_name=self._destination_resource.storage_url.bucket_name,
        )
      else:
        api = api_factory.get_api(provider)
      uploaded_object_resource = api.upload_object(
          source_stream=stream,
          destination_resource=self._destination_resource,
          request_config=request_config,
          posix_to_set=self._posix_to_set,
          source_resource=self._source_resource,
          upload_strategy=cloud_api.UploadStrategy.STREAMING,
      )

    # Verify client-side digests against the server-reported hashes; deletes
    # the uploaded object and raises on mismatch.
    upload_util.validate_uploaded_object(
        digesters,
        uploaded_object_resource,
        task_status_queue)
    self._print_created_message_if_requested(uploaded_object_resource)

View File

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for tasks that upload files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.command_lib.storage.tasks import task
class UploadTask(task.Task):
  """Base class for tasks that upload files."""

  def __init__(self, source_resource, destination_resource, length):
    """Initializes a task instance.

    Args:
      source_resource (resource_reference.FileObjectResource): The file to
        upload.
      destination_resource (resource_reference.ObjectResource|UnknownResource):
        Destination metadata for the upload.
      length (int): The size of source_resource.
    """
    super(UploadTask, self).__init__()
    self._source_resource = source_resource
    self._destination_resource = destination_resource
    self._length = length

  def __eq__(self, other):
    """Tasks are equal when source, destination, and length all match."""
    if not isinstance(other, self.__class__):
      return NotImplemented
    return (
        (self._source_resource, self._destination_resource, self._length)
        == (other._source_resource, other._destination_resource, other._length)
    )

View File

@@ -0,0 +1,249 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for performing upload operation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import mimetypes
import os
import subprocess
import threading
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import cloud_api
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import buffered_upload_stream
from googlecloudsdk.command_lib.storage import component_stream
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import fast_crc32c_util
from googlecloudsdk.command_lib.storage import hash_util
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage import upload_stream
from googlecloudsdk.command_lib.storage.tasks import task_status
from googlecloudsdk.command_lib.storage.tasks.rm import delete_task
from googlecloudsdk.command_lib.util import crc32c
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import files
from googlecloudsdk.core.util import hashing
from googlecloudsdk.core.util import platforms
from googlecloudsdk.core.util import scaled_integer
# Extension-to-MIME-type overrides for extensions that the mimetypes library
# and the "file" command do not recognize. Checked before any other detection.
COMMON_EXTENSION_RULES = {
    '.md': 'text/markdown',  # b/169088193
    '.tgz': 'application/gzip',  # b/179176339
}
def get_upload_strategy(api, object_length):
  """Determines if resumable upload should be performed.

  Args:
    api (CloudApi): An api instance to check if it supports resumable upload.
    object_length (int): Length of the data to be uploaded.

  Returns:
    cloud_api.UploadStrategy: RESUMABLE if the object is at least the
    configured storage/resumable_threshold and the API supports resumable
    uploads; SIMPLE otherwise.
  """
  resumable_threshold = scaled_integer.ParseInteger(
      properties.VALUES.storage.resumable_threshold.Get())
  if (object_length >= resumable_threshold and
      cloud_api.Capability.RESUMABLE_UPLOAD in api.capabilities):
    return cloud_api.UploadStrategy.RESUMABLE
  else:
    return cloud_api.UploadStrategy.SIMPLE
def get_content_type(source_path, is_stream):
  """Gets a file's MIME type.

  Favors returning the result of `file -b --mime ...` if the command is
  available and users have enabled it. Otherwise, it returns a type based on
  the file's extension.

  Args:
    source_path (str): Path to file. May differ from file_resource.storage_url
      if using a temporary file (e.g. for gzipping).
    is_stream (bool): If the source file is a pipe (typically FIFO or stdin).

  Returns:
    A MIME type (str). If a type cannot be guessed,
    request_config_factory.DEFAULT_CONTENT_TYPE is returned.
  """
  if is_stream:
    return request_config_factory.DEFAULT_CONTENT_TYPE

  # Hard-coded rules take precedence because neither the mimetypes library
  # nor the "file" command recognizes these extensions.
  for extension, mime_type in COMMON_EXTENSION_RULES.items():
    if source_path.endswith(extension):
      return mime_type

  use_magicfile = (
      not platforms.OperatingSystem.IsWindows()
      and properties.VALUES.storage.use_magicfile.GetBool())
  if use_magicfile:
    completed_process = subprocess.run(['file', '-b', '--mime', source_path],
                                       check=True,
                                       stdout=subprocess.PIPE,
                                       universal_newlines=True)
    guessed_type = completed_process.stdout.strip()
  else:
    guessed_type, _ = mimetypes.guess_type(source_path)

  # Empty string or None both fall through to the default content type.
  return guessed_type or request_config_factory.DEFAULT_CONTENT_TYPE
def get_digesters(source_resource, destination_resource):
  """Gets appropriate hash objects for upload validation.

  Args:
    source_resource (resource_reference.FileObjectResource): The upload source.
    destination_resource (resource_reference.ObjectResource): The upload
      destination.

  Returns:
    A dict[hash_util.HashAlgorithm, hash object], the values of which should be
    updated with uploaded bytes.
  """
  provider = destination_resource.storage_url.scheme
  if properties.VALUES.storage.enable_zonal_buckets_bidi_streaming.GetBool():
    bucket_name = destination_resource.storage_url.bucket_name
  else:
    bucket_name = None
  capabilities = api_factory.get_capabilities(provider, bucket_name)
  check_hashes = properties.CheckHashes(
      properties.VALUES.storage.check_hashes.Get())

  if cloud_api.Capability.APPENDABLE_UPLOAD in capabilities:
    # APIs supporting appendable uploads validate with CRC32C only.
    fast_crc32c_util.log_or_raise_crc32c_issues()
    use_crc32c = (
        check_hashes == properties.CheckHashes.ALWAYS
        or fast_crc32c_util.check_if_will_use_fast_crc32c()
    )
    if use_crc32c:
      return {hash_util.HashAlgorithm.CRC32C: fast_crc32c_util.get_crc32c()}
    return {}

  skip_md5 = (
      source_resource.md5_hash
      or cloud_api.Capability.CLIENT_SIDE_HASH_VALIDATION in capabilities
      or check_hashes == properties.CheckHashes.NEVER
  )
  if skip_md5:
    return {}
  return {hash_util.HashAlgorithm.MD5: hashing.get_md5()}
def get_stream(source_resource,
               length=None,
               offset=None,
               digesters=None,
               task_status_queue=None,
               destination_resource=None,
               component_number=None,
               total_components=None):
  """Gets a stream to use for an upload.

  Args:
    source_resource (resource_reference.FileObjectResource): Contains a path to
      the source file.
    length (int|None): The total number of bytes to be uploaded.
    offset (int|None): The position of the first byte to be uploaded.
    digesters (dict[hash_util.HashAlgorithm, hash object]|None): Hash objects to
      be populated as bytes are read.
    task_status_queue (multiprocessing.Queue|None): Used for sending progress
      messages. If None, no messages will be generated or sent.
    destination_resource (resource_reference.ObjectResource|None): The upload
      destination. Used for progress reports, and should be specified if
      task_status_queue is.
    component_number (int|None): Identifies a component in composite uploads.
    total_components (int|None): The total number of components used in a
      composite upload.

  Returns:
    An UploadStream wrapping the file specified by source_resource.
  """
  if task_status_queue:
    progress_callback = progress_callbacks.FilesAndBytesProgressCallback(
        status_queue=task_status_queue,
        offset=offset or 0,
        length=length,
        source_url=source_resource.storage_url,
        destination_url=destination_resource.storage_url,
        component_number=component_number,
        total_components=total_components,
        operation_name=task_status.OperationName.UPLOADING,
        process_id=os.getpid(),
        thread_id=threading.get_ident(),
    )
  else:
    progress_callback = None

  if source_resource.storage_url.is_stdio:
    # Read stdin (file descriptor 0) in binary mode.
    source_stream = os.fdopen(0, 'rb')
  else:
    source_stream = files.BinaryFileReader(
        source_resource.storage_url.resource_name)

  if source_resource.storage_url.is_stream:
    # Pipes are not seekable, so buffer reads up to the configured chunk size.
    max_buffer_size = scaled_integer.ParseBinaryInteger(
        properties.VALUES.storage.upload_chunk_size.Get())
    return buffered_upload_stream.BufferedUploadStream(
        source_stream,
        max_buffer_size=max_buffer_size,
        digesters=digesters,
        progress_callback=progress_callback)
  elif offset is None:
    # Whole-file upload.
    return upload_stream.UploadStream(
        source_stream,
        length=length,
        digesters=digesters,
        progress_callback=progress_callback)
  else:
    # Slice of the file, used for one component of a composite upload.
    return component_stream.ComponentStream(
        source_stream, offset=offset, length=length, digesters=digesters,
        progress_callback=progress_callback)
def validate_uploaded_object(digesters, uploaded_resource, task_status_queue):
  """Raises error if hashes for uploaded_resource and digesters do not match.

  Args:
    digesters (dict[hash_util.HashAlgorithm, hash object]): Hashes computed
      client-side while the object's bytes were read.
    uploaded_resource (resource_reference.ObjectResource): Metadata of the
      uploaded object, carrying the server-reported hashes.
    task_status_queue (multiprocessing.Queue|None): Passed to the delete task
      for progress messages if validation fails.

  Raises:
    errors.Error: If digesters contains only unsupported hash algorithms.
    errors.HashMismatchError: If local and remote hashes differ. The uploaded
      object is deleted before re-raising.
  """
  if not digesters:
    return
  if hash_util.HashAlgorithm.MD5 in digesters:
    calculated_digest = hash_util.get_base64_hash_digest_string(
        digesters[hash_util.HashAlgorithm.MD5])
    destination_hash = uploaded_resource.md5_hash
  elif hash_util.HashAlgorithm.CRC32C in digesters:
    calculated_digest = crc32c.get_hash(
        digesters[hash_util.HashAlgorithm.CRC32C]
    )
    destination_hash = uploaded_resource.crc32c_hash
  else:
    # Stringify keys before joining: the keys are HashAlgorithm enum members,
    # and str.join raises TypeError on non-string values, which would mask
    # this error with an unrelated traceback.
    raise errors.Error(
        'Unsupported hash algorithm(s) found in digesters: {}'.format(
            ', '.join(str(key) for key in digesters)
        )
    )
  try:
    hash_util.validate_object_hashes_match(
        uploaded_resource.storage_url.url_string, calculated_digest,
        destination_hash)
  except errors.HashMismatchError:
    # Don't leave a corrupt object in the bucket; delete it, then re-raise.
    delete_task.DeleteObjectTask(uploaded_resource.storage_url).execute(
        task_status_queue=task_status_queue
    )
    raise

View File

@@ -0,0 +1,217 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for listing, sorting, and writing files for rsync."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import errno
import heapq
import itertools
import os
import threading
from googlecloudsdk.api_lib.storage import cloud_api
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import folder_util
from googlecloudsdk.command_lib.storage import regex_util
from googlecloudsdk.command_lib.storage import rsync_command_util
from googlecloudsdk.command_lib.storage import storage_url
from googlecloudsdk.command_lib.storage import wildcard_iterator
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import files
def sorting_key_for_csv_line(csv_line):
  """Returns the sorting key for a chunk CSV line.

  Only the first CSV field (the resource URL) is used as the key. Sorting on
  the entire line would be unreliable because characters such as #, $, and "
  can appear before the delimiting comma and perturb the ordering.

  Args:
    csv_line (str): The CSV line to get the sorting key for.

  Returns:
    str: The sorting key for the CSV line.
  """
  return rsync_command_util.get_fields_from_csv_line(csv_line)[0]
class GetSortedContainerContentsTask(task.Task):
  """Lists a container's contents and writes them, sorted, to a file.

  Fetches object (or managed folder) resources under a container, sorts them
  in chunks, and merges the sorted chunks into a single output file for rsync.
  (The previous docstring, "Updates a local file's POSIX metadata.", was a
  copy-paste error and did not describe this class.)
  """

  def __init__(
      self,
      container,
      output_path,
      exclude_pattern_strings=None,
      managed_folders_only=False,
      ignore_symlinks=True,
      recurse=False,
  ):
    """Initializes task.

    Args:
      container (Resource): Contains path of files to fetch.
      output_path (str): Where to write final sorted file list.
      exclude_pattern_strings (List[str]|None): Ignore resources whose paths
        matched these regex patterns.
      managed_folders_only (bool): If True, populates the file with managed
        folders. Otherwise, populates the file with object resources.
      ignore_symlinks (bool): Should FileWildcardIterator skip symlinks.
      recurse (bool): Gather nested items in container.
    """
    super(GetSortedContainerContentsTask, self).__init__()
    # '**' matches nested items; '*' only the container's direct children.
    self._container_query_path = container.storage_url.join(
        '**' if recurse else '*'
    ).url_string
    self._output_path = output_path
    if exclude_pattern_strings:
      container_url_trailing_delimiter = container.storage_url.join('')
      if isinstance(container_url_trailing_delimiter, storage_url.FileUrl):
        # Remove 'file://' prefix.
        container_prefix = container_url_trailing_delimiter.resource_name
      else:
        container_prefix = (
            container_url_trailing_delimiter.versionless_url_string
        )
      self._exclude_patterns = regex_util.Patterns(
          exclude_pattern_strings,
          # Confirm container URL ends in a delimiter.
          ignore_prefix_length=len(container_prefix),
      )
    else:
      self._exclude_patterns = None
    self._managed_folders_only = managed_folders_only
    self._ignore_symlinks = ignore_symlinks
    # Used only for progress log messages.
    self._worker_id = 'process {} thread {}'.format(
        os.getpid(), threading.get_ident()
    )

  def execute(self, task_status_queue=None):
    """Lists, chunk-sorts, and merge-writes the container contents.

    Listing is consumed in fixed-size chunks; each chunk is sorted and written
    to a temporary file, and the sorted chunk files are then merged into
    self._output_path (an external merge sort, bounding memory use).

    Args:
      task_status_queue: Unused.

    Raises:
      errors.Error: If too many chunk files are open at once (EMFILE).
    """
    del task_status_queue  # Unused.
    if self._managed_folders_only:
      managed_folder_setting = (
          folder_util.ManagedFolderSetting.LIST_WITHOUT_OBJECTS
      )
    else:
      managed_folder_setting = folder_util.ManagedFolderSetting.DO_NOT_LIST

    file_iterator = iter(
        wildcard_iterator.get_wildcard_iterator(
            self._container_query_path,
            exclude_patterns=self._exclude_patterns,
            fetch_encrypted_object_hashes=(
                properties.VALUES.storage.check_hashes.Get()
                != properties.CheckHashes.NEVER.value
            ),
            fields_scope=cloud_api.FieldsScope.RSYNC,
            files_only=not self._managed_folders_only,
            force_include_hidden_files=True,
            ignore_symlinks=self._ignore_symlinks,
            managed_folder_setting=managed_folder_setting,
        )
    )
    chunk_count = file_count = 0
    chunk_file_paths = []
    chunk_file_readers = []
    chunk_size = properties.VALUES.storage.rsync_list_chunk_size.GetInt()
    try:
      while True:
        # Pull at most chunk_size resources at a time to bound memory use.
        resources_chunk = list(itertools.islice(file_iterator, chunk_size))
        if not resources_chunk:
          break
        chunk_count += 1
        file_count += len(resources_chunk)
        log.status.Print(
            'At {}, worker {} listed {}...'.format(
                self._container_query_path, self._worker_id, file_count
            )
        )

        chunk_file_paths.append(
            rsync_command_util.get_hashed_list_file_path(
                self._container_query_path,
                chunk_count,
                is_managed_folder_list=self._managed_folders_only,
            )
        )
        if properties.VALUES.storage.use_url_based_rsync_sorting.GetBool():
          # Sort on the URL field only (see sorting_key_for_csv_line).
          sorted_encoded_chunk = sorted(
              [
                  rsync_command_util.get_csv_line_from_resource(x)
                  for x in resources_chunk
              ],
              key=sorting_key_for_csv_line,
          )
        else:
          sorted_encoded_chunk = sorted([
              rsync_command_util.get_csv_line_from_resource(x)
              for x in resources_chunk
          ])
        sorted_encoded_chunk.append('')  # Add trailing newline.
        files.WriteFileContents(
            chunk_file_paths[-1],
            '\n'.join(sorted_encoded_chunk),
        )

      # Merge the sorted chunk files; the same key must be used here as when
      # the chunks were sorted above for the merge to be correct.
      chunk_file_readers = [files.FileReader(path) for path in chunk_file_paths]
      with files.FileWriter(self._output_path, create_path=True) as file_writer:
        if properties.VALUES.storage.use_url_based_rsync_sorting.GetBool():
          file_writer.writelines(
              heapq.merge(*chunk_file_readers, key=sorting_key_for_csv_line)
          )
        else:
          file_writer.writelines(heapq.merge(*chunk_file_readers))
    except OSError as e:
      if e.errno == errno.EMFILE:
        raise errors.Error(
            'Too many open chunk files. Try increasing the'
            ' size with `gcloud config set storage/rsync_list_chunk_size`.'
            ' The current size is {}.'.format(chunk_size)
        )
      raise e
    finally:
      # Best-effort cleanup of the temporary chunk files.
      for reader in chunk_file_readers:
        try:
          reader.close()
        except Exception as e:  # pylint:disable=broad-except
          log.debug('Failed to close file reader {}: {}'.format(reader.name, e))
      for path in chunk_file_paths:
        rsync_command_util.try_to_delete_file(path)

  def __eq__(self, other):
    """Compares tasks on the attributes that define a listing operation."""
    if not isinstance(other, type(self)):
      return NotImplemented
    return (
        self._container_query_path == other._container_query_path
        and self._exclude_patterns == other._exclude_patterns
        and self._managed_folders_only == other._managed_folders_only
        and self._ignore_symlinks == other._ignore_symlinks
        and self._output_path == other._output_path
    )

View File

@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for bulk restoring soft-deleted objects."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class BulkRestoreObjectsTask(task.Task):
  """Restores soft-deleted cloud storage objects."""

  def __init__(
      self,
      bucket_url,
      object_globs,
      allow_overwrite=False,
      created_after_time=None,
      created_before_time=None,
      deleted_after_time=None,
      deleted_before_time=None,
      user_request_args=None,
  ):
    """Initializes task.

    Args:
      bucket_url (StorageUrl): Launch a bulk restore operation for this bucket.
      object_globs (list[str]): Objects in the target bucket matching these
        glob patterns will be restored.
      allow_overwrite (bool): Overwrite existing live objects.
      created_after_time (datetime): Filter results to objects created after
        this time.
      created_before_time (datetime): Filter results to objects created before
        this time.
      deleted_after_time (datetime): Filter results to objects soft-deleted
        after this time.
      deleted_before_time (datetime): Filter results to objects soft-deleted
        before this time.
      user_request_args (UserRequestArgs|None): Contains restore settings.
    """
    super(BulkRestoreObjectsTask, self).__init__()
    self._bucket_url = bucket_url
    self._object_globs = object_globs
    self._allow_overwrite = allow_overwrite
    self._created_after_time = created_after_time
    self._created_before_time = created_before_time
    self._deleted_after_time = deleted_after_time
    self._deleted_before_time = deleted_before_time
    self._user_request_args = user_request_args

  def execute(self, task_status_queue=None):
    """Launches the bulk restore operation and reports progress."""
    log.status.Print(
        'Creating bulk restore operation for bucket {} with globs: {}'.format(
            self._bucket_url, self._object_globs
        )
    )
    # Arbitrarily use first glob to get CloudUrl for object.
    first_object_url = self._bucket_url.join(self._object_globs[0])
    request_config = request_config_factory.get_request_config(
        first_object_url,
        user_request_args=self._user_request_args,
    )
    api_client = api_factory.get_api(self._bucket_url.scheme)
    created_operation = api_client.bulk_restore_objects(
        self._bucket_url,
        self._object_globs,
        request_config=request_config,
        allow_overwrite=self._allow_overwrite,
        created_after_time=self._created_after_time,
        created_before_time=self._created_before_time,
        deleted_after_time=self._deleted_after_time,
        deleted_before_time=self._deleted_before_time,
    )
    log.status.Print('Created: ' + created_operation.name)
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    """Returns True if other is an equivalent bulk restore task."""
    if not isinstance(other, type(self)):
      return NotImplemented
    return all((
        self._bucket_url == other._bucket_url,
        self._object_globs == other._object_globs,
        self._allow_overwrite == other._allow_overwrite,
        self._created_after_time == other._created_after_time,
        self._created_before_time == other._created_before_time,
        self._deleted_after_time == other._deleted_after_time,
        self._deleted_before_time == other._deleted_before_time,
        self._user_request_args == other._user_request_args,
    ))

View File

@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for updating an object's metadata."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class PatchObjectTask(task.Task):
  """Updates a cloud storage object's metadata."""

  def __init__(
      self, object_resource, posix_to_set=None, user_request_args=None
  ):
    """Initializes task.

    Args:
      object_resource (resource_reference.ObjectResource): The object to
        update.
      posix_to_set (PosixAttributes|None): POSIX info set as custom cloud
        metadata on target.
      user_request_args (UserRequestArgs|None): Describes metadata updates to
        perform.
    """
    super(PatchObjectTask, self).__init__()
    self._object_resource = object_resource
    self._posix_to_set = posix_to_set
    self._user_request_args = user_request_args

  def execute(self, task_status_queue=None):
    """Issues the metadata patch API call and reports progress."""
    log.status.Print('Patching {}...'.format(self._object_resource))
    storage_url = self._object_resource.storage_url
    request_config = request_config_factory.get_request_config(
        storage_url, user_request_args=self._user_request_args
    )
    api_client = api_factory.get_api(storage_url.scheme)
    api_client.patch_object_metadata(
        storage_url.bucket_name,
        storage_url.resource_name,
        self._object_resource,
        request_config=request_config,
        posix_to_set=self._posix_to_set,
    )
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    """Returns True if other is an equivalent patch task."""
    if not isinstance(other, type(self)):
      return NotImplemented
    return all((
        self._object_resource == other._object_resource,
        self._posix_to_set == other._posix_to_set,
        self._user_request_args == other._user_request_args,
    ))

View File

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for restoring a soft-deleted object."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class RestoreObjectTask(task.Task):
  """Restores a soft-deleted cloud storage object."""

  def __init__(self, object_resource, user_request_args=None):
    """Initializes task.

    Args:
      object_resource (resource_reference.ObjectResource): Object to restore.
      user_request_args (UserRequestArgs|None): Contains restore settings.
    """
    super(RestoreObjectTask, self).__init__()
    self._object_resource = object_resource
    self._user_request_args = user_request_args

  def execute(self, task_status_queue=None):
    """Issues the restore API call and reports progress."""
    log.status.Print('Restoring {}...'.format(self._object_resource))
    storage_url = self._object_resource.storage_url
    request_config = request_config_factory.get_request_config(
        storage_url,
        user_request_args=self._user_request_args,
    )
    api_factory.get_api(storage_url.scheme).restore_object(
        storage_url,
        request_config=request_config,
    )
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    """Returns True if other is an equivalent restore task."""
    if not isinstance(other, type(self)):
      return NotImplemented
    return all((
        self._object_resource == other._object_resource,
        self._user_request_args == other._user_request_args,
    ))

View File

@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for rewriting an object's underlying data to update the metadata."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import encryption_util
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.command_lib.storage.tasks.objects import patch_object_task
from googlecloudsdk.core import log
class RewriteObjectTask(task.Task):
  """Rewrites a cloud storage object's underlying data, changing metadata."""

  def __init__(self, object_resource, user_request_args=None):
    """Initializes task.

    Args:
      object_resource (resource_reference.ObjectResource): The object to update.
      user_request_args (UserRequestArgs|None): Describes metadata updates to
        perform.
    """
    super(RewriteObjectTask, self).__init__()
    self._object_resource = object_resource
    self._user_request_args = user_request_args

  def execute(self, task_status_queue=None):
    """Rewrites the object, or falls back to a patch when nothing changes.

    Args:
      task_status_queue (multiprocessing.Queue|None): Used to report progress
        to a central location.

    Returns:
      task.Output containing a PatchObjectTask when neither encryption nor
      storage class would change, otherwise None after performing the copy.
    """
    log.status.Print('Rewriting {}...'.format(self._object_resource))
    provider = self._object_resource.storage_url.scheme
    request_config = request_config_factory.get_request_config(
        self._object_resource.storage_url,
        user_request_args=self._user_request_args)
    api_client = api_factory.get_api(provider)
    # Fetch the object's current metadata to decide whether a full data
    # rewrite is actually required.
    existing_object_resource = api_client.get_object_metadata(
        self._object_resource.storage_url.bucket_name,
        self._object_resource.storage_url.resource_name,
        generation=self._object_resource.storage_url.generation,
        request_config=request_config)
    if existing_object_resource.kms_key:  # Existing CMEK.
      # getattr handles the case where no new key was requested (None).
      encryption_changing = existing_object_resource.kms_key != getattr(
          encryption_util.get_encryption_key(), 'key', None)
    elif existing_object_resource.decryption_key_hash_sha256:  # Existing CSEK.
      # CSEK keys are compared by SHA-256 hash; the raw key is not available.
      encryption_changing = (
          existing_object_resource.decryption_key_hash_sha256 != getattr(
              encryption_util.get_encryption_key(), 'sha256', None))
    else:  # No existing encryption.
      # Clear flag can still reset an object to bucket's default encryption.
      encryption_changing = encryption_util.get_encryption_key() is not None
    new_storage_class = getattr(request_config.resource_args, 'storage_class',
                                None)
    storage_class_changing = (
        new_storage_class and
        new_storage_class != existing_object_resource.storage_class)
    if not (encryption_changing or storage_class_changing):
      # Nothing requires rewriting bytes; delegate to a cheaper metadata
      # patch task instead.
      log.warning('Proposed encryption key and storage class for' +
                  ' {} match the existing data.'.format(self._object_resource) +
                  ' Performing patch instead of rewrite.')
      return task.Output(
          additional_task_iterators=[
              [
                  patch_object_task.PatchObjectTask(
                      self._object_resource,
                      user_request_args=self._user_request_args,
                  )
              ]
          ],
          messages=None,
      )
    if storage_class_changing and not encryption_changing:
      # Preserve current encryption.
      new_encryption_key = encryption_util.get_encryption_key(
          existing_object_resource.decryption_key_hash_sha256,
          self._object_resource.storage_url)
    else:
      new_encryption_key = encryption_util.get_encryption_key()
    # Build a config that can both decrypt the existing data (hash of old
    # key) and encrypt the rewritten copy (new key).
    request_config_with_encryption = request_config_factory.get_request_config(
        self._object_resource.storage_url,
        user_request_args=self._user_request_args,
        decryption_key_hash_sha256=existing_object_resource
        .decryption_key_hash_sha256,
        encryption_key=new_encryption_key)
    # Copy the object over itself; deep metadata copy keeps all other fields.
    api_client.copy_object(
        existing_object_resource,
        self._object_resource,
        request_config_with_encryption,
        should_deep_copy_metadata=True)
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    """Returns True if other is an equivalent rewrite task."""
    if not isinstance(other, type(self)):
      return NotImplemented
    return (self._object_resource == other._object_resource and
            self._user_request_args == other._user_request_args)

View File

@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for updating a local file's POSIX metadata."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.command_lib.storage import posix_util
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class PatchFilePosixTask(task.Task):
  """Updates a local file's POSIX metadata."""

  def __init__(
      self,
      system_posix_data,
      source_resource,
      destination_resource,
      known_source_posix=None,
      known_destination_posix=None,
  ):
    """Initializes task.

    Args:
      system_posix_data (SystemPosixData): Contains system-wide POSIX
        metadata.
      source_resource (resource_reference.ObjectResource): Contains custom
        POSIX metadata and URL for error logging.
      destination_resource (resource_reference.FileObjectResource): File to
        set POSIX metadata on.
      known_source_posix (PosixAttributes|None): Use pre-parsed POSIX data
        instead of extracting from source.
      known_destination_posix (PosixAttributes|None): Use pre-parsed POSIX
        data instead of extracting from destination.
    """
    super(PatchFilePosixTask, self).__init__()
    self._system_posix_data = system_posix_data
    self._source_resource = source_resource
    self._destination_resource = destination_resource
    self._known_source_posix = known_source_posix
    self._known_destination_posix = known_destination_posix

  def execute(self, task_status_queue=None):
    """Applies POSIX attributes to the destination file if they are valid."""
    log.status.Print('Patching {}...'.format(self._destination_resource))
    posix_util.set_posix_attributes_on_file_if_valid(
        self._system_posix_data,
        self._source_resource,
        self._destination_resource,
        known_source_posix=self._known_source_posix,
        known_destination_posix=self._known_destination_posix,
    )
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    """Returns True if other is an equivalent POSIX patch task."""
    if not isinstance(other, type(self)):
      return NotImplemented
    attributes = (
        '_system_posix_data',
        '_source_resource',
        '_destination_resource',
        '_known_source_posix',
        '_known_destination_posix',
    )
    return all(
        getattr(self, attribute) == getattr(other, attribute)
        for attribute in attributes
    )

View File

@@ -0,0 +1,151 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tasks for deleting resources."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import abc
import os
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.api_lib.storage import request_config_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.core import log
class DeleteTask(task.Task):
  """Base class for tasks that delete a resource."""

  def __init__(self, url, user_request_args=None, verbose=True):
    """Initializes task.

    Args:
      url (storage_url.StorageUrl): URL of the resource to delete.
      user_request_args (UserRequestArgs|None): Values for RequestConfig.
      verbose (bool): If true, prints status messages. Otherwise, does not
        print anything.
    """
    super().__init__()
    self._url = url
    self._user_request_args = user_request_args
    self._verbose = verbose
    # Keeps the executor from running two deletions of the same URL at once.
    self.parallel_processing_key = url.url_string

  @abc.abstractmethod
  def _perform_deletion(self):
    """Deletes a resource. Overridden by children."""
    raise NotImplementedError

  def execute(self, task_status_queue=None):
    """Optionally prints status, deletes the resource, reports progress."""
    if self._verbose:
      log.status.Print('Removing {}...'.format(self._url))
    self._perform_deletion()
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)

  def __eq__(self, other):
    """Returns True if other is an equivalent delete task."""
    if not isinstance(other, self.__class__):
      return NotImplemented
    return all((
        self._url == other._url,
        self._user_request_args == other._user_request_args,
        self._verbose == other._verbose,
    ))
class DeleteFileTask(DeleteTask):
  """Task to delete a file."""

  def _perform_deletion(self):
    """Removes the file at this task's URL from the local filesystem."""
    # os.unlink is an alias of os.remove.
    os.unlink(self._url.resource_name)
class CloudDeleteTask(DeleteTask):
  """Base class for tasks that delete a cloud resource."""

  @abc.abstractmethod
  def _make_delete_api_call(self, client, request_config):
    """Performs an API call to delete a resource. Overridden by children."""
    raise NotImplementedError

  def _perform_deletion(self):
    """Builds an API client and request config, then delegates deletion."""
    api_client = api_factory.get_api(self._url.scheme)
    request_config = request_config_factory.get_request_config(
        self._url, user_request_args=self._user_request_args
    )
    return self._make_delete_api_call(api_client, request_config)
class DeleteBucketTask(CloudDeleteTask):
  """Task to delete a bucket."""

  def _make_delete_api_call(self, client, request_config):
    """Deletes the bucket, attaching a hint when the bucket is not empty.

    Args:
      client: Provider-specific API client from api_factory.
      request_config: RequestConfig describing the user's delete request.

    Raises:
      Exception: Re-raises whatever the client raises. If the error message
        indicates a non-empty bucket, raises the same exception type with a
        `gcloud storage rm -r` hint, chained to the original error.
    """
    try:
      client.delete_bucket(self._url.bucket_name, request_config)
    # pylint:disable=broad-except
    except Exception as error:
      # pylint:enable=broad-except
      if 'not empty' in str(error):
        # Chain the original error explicitly so the root cause (and its
        # traceback) is preserved instead of being replaced by the hint.
        raise type(error)(
            'Bucket is not empty. To delete all objects and then delete'
            ' bucket, use: gcloud storage rm -r'
        ) from error
      else:
        raise
class DeleteManagedFolderTask(CloudDeleteTask):
  """Task to delete a managed folder."""

  @property
  def managed_folder_url(self):
    """The URL of the resource deleted by this task.

    Exposing this allows execution to respect containment order.
    """
    return self._url

  def _make_delete_api_call(self, client, request_config):
    """Calls the managed-folder delete API for this task's URL."""
    del request_config  # Unused.
    client.delete_managed_folder(
        self._url.bucket_name, self._url.resource_name
    )
class DeleteFolderTask(CloudDeleteTask):
  """Task to delete a folder."""

  @property
  def folder_url(self):
    """The URL of the resource deleted by this task.

    Exposing this allows execution to respect containment order.
    """
    return self._url

  def _make_delete_api_call(self, client, request_config):
    """Calls the folder delete API for this task's URL."""
    del request_config  # Unused.
    client.delete_folder(
        self._url.bucket_name, self._url.resource_name
    )
class DeleteObjectTask(CloudDeleteTask):
  """Task to delete an object."""

  def _make_delete_api_call(self, client, request_config):
    """Issues the object delete call for this task's URL."""
    client.delete_object(self._url, request_config)

View File

@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Iterator for deleting buckets and objects."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.resources import resource_reference
from googlecloudsdk.command_lib.storage.tasks.rm import delete_task
from six.moves import queue
class DeleteTaskIteratorFactory:
  """Creates bucket and object delete task iterators."""

  def __init__(self,
               name_expansion_iterator,
               task_status_queue=None,
               user_request_args=None):
    """Initializes factory.

    Args:
      name_expansion_iterator (NameExpansionIterator): Iterable of wildcard
        iterators to flatten.
      task_status_queue (multiprocessing.Queue|None): Used for estimating total
        workload from this iterator.
      user_request_args (UserRequestArgs|None): Values for RequestConfig.
    """
    self._name_expansion_iterator = name_expansion_iterator
    self._task_status_queue = task_status_queue
    self._user_request_args = user_request_args
    # One queue per resource category. Queues are filled lazily as the
    # shared generator below is advanced.
    self._bucket_delete_tasks = queue.Queue()
    self._managed_folder_delete_tasks = queue.Queue()
    self._folder_delete_tasks = queue.Queue()
    self._object_delete_tasks = queue.Queue()
    # Single shared generator: advancing it from ANY category iterator
    # routes the next discovered resource into the appropriate queue.
    self._flat_wildcard_results_iterator = (
        self._get_flat_wildcard_results_iterator())

  def _get_flat_wildcard_results_iterator(self):
    """Iterates through items matching delete query, dividing into two lists.

    Separates objects and buckets, so we can return two separate iterators.

    Yields:
      True if resource found.
    """
    for name_expansion_result in self._name_expansion_iterator:
      resource = name_expansion_result.resource
      resource_url = resource.storage_url
      # The wildcard iterator can return UnknownResources, so we use URLs to
      # check for buckets.
      if resource_url.is_bucket():
        # NOTE(review): bucket tasks are built without user_request_args,
        # unlike object tasks below — confirm this is intentional.
        self._bucket_delete_tasks.put(
            delete_task.DeleteBucketTask(resource_url)
        )
      elif isinstance(resource, resource_reference.ManagedFolderResource):
        self._managed_folder_delete_tasks.put(
            delete_task.DeleteManagedFolderTask(resource_url)
        )
      elif isinstance(resource, resource_reference.FolderResource):
        self._folder_delete_tasks.put(
            delete_task.DeleteFolderTask(resource_url)
        )
      else:
        self._object_delete_tasks.put(
            delete_task.DeleteObjectTask(
                resource_url, user_request_args=self._user_request_args
            )
        )
      yield True

  def _resource_iterator(self, resource_queue):
    """Yields a resource from the queue."""
    resource_count = 0
    try:
      # Drain resource_queue; when it is empty, pump the shared wildcard
      # generator, which may enqueue items here or in sibling queues.
      # Exhaustion of the generator raises StopIteration, caught below.
      while not resource_queue.empty() or next(
          self._flat_wildcard_results_iterator
      ):
        if not resource_queue.empty():
          resource_count += 1
          yield resource_queue.get()
    except StopIteration:
      pass
    if resource_count:
      # Contribute this category's count to the overall workload estimate.
      progress_callbacks.workload_estimator_callback(
          self._task_status_queue, resource_count
      )

  def bucket_iterator(self):
    """Returns an iterator of bucket delete tasks."""
    return self._resource_iterator(self._bucket_delete_tasks)

  def managed_folder_iterator(self):
    """Returns an iterator of managed folder delete tasks."""
    return self._resource_iterator(self._managed_folder_delete_tasks)

  def folder_iterator(self):
    """Returns an iterator of folder delete tasks."""
    return self._resource_iterator(self._folder_delete_tasks)

  def object_iterator(self):
    """Returns an iterator of object delete tasks."""
    return self._resource_iterator(self._object_delete_tasks)

View File

@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for IAM policies on storage resources."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import abc
from googlecloudsdk.api_lib.storage import api_factory
from googlecloudsdk.command_lib.storage import progress_callbacks
from googlecloudsdk.command_lib.storage.tasks import task
class _SetIamPolicyTask(task.Task):
  """Base class for tasks that set IAM policies."""

  def __init__(self, url, policy):
    """Initializes task.

    Args:
      url (StorageUrl): Used to identify cloud resource to set policy on.
      policy (object): Provider-specific data type. Currently, only available
        for GCS so Apitools messages.Policy object. If supported for more
        providers in the future, use a generic container.
    """
    super(_SetIamPolicyTask, self).__init__()
    self._url = url
    self._policy = policy

  @abc.abstractmethod
  def _make_set_api_call(self, client):
    """Makes an API call to set the IAM policy. Overridden by children."""
    pass

  def execute(self, task_status_queue=None):
    """Sets the policy and emits the server's returned policy as a message."""
    api_client = api_factory.get_api(self._url.scheme)
    updated_policy = self._make_set_api_call(api_client)
    if task_status_queue:
      progress_callbacks.increment_count_callback(task_status_queue)
    # Downstream consumers read the new policy off the message payload.
    return task.Output(
        additional_task_iterators=None,
        messages=[
            task.Message(task.Topic.SET_IAM_POLICY, payload=updated_policy)
        ],
    )

  def __eq__(self, other):
    """Returns True if other is an equivalent set-policy task."""
    if not isinstance(other, _SetIamPolicyTask):
      return NotImplemented
    return self._url == other._url and self._policy == other._policy
class SetBucketIamPolicyTask(_SetIamPolicyTask):
  """Sets an IAM policy on a bucket."""

  def _make_set_api_call(self, client):
    """Sends the bucket IAM policy to the API."""
    return client.set_bucket_iam_policy(self._url.bucket_name, self._policy)
class SetManagedFolderIamPolicyTask(_SetIamPolicyTask):
  """Sets an IAM policy on a managed folder."""

  def _make_set_api_call(self, client):
    """Sends the managed folder IAM policy to the API."""
    return client.set_managed_folder_iam_policy(
        self._url.bucket_name, self._url.resource_name, self._policy
    )
class SetObjectIamPolicyTask(_SetIamPolicyTask):
  """Sets an IAM policy on an object."""

  def _make_set_api_call(self, client):
    """Sends the object IAM policy to the API, pinned to the generation."""
    return client.set_object_iam_policy(
        self._url.bucket_name,
        self._url.resource_name,
        self._policy,
        generation=self._url.generation,
    )

View File

@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract operation class that command operations will inherit from.
Should typically be executed in a task iterator through
googlecloudsdk.command_lib.storage.tasks.task_executor.
Manual execution example:
>>> class CopyTask(Task):
... def __init__(self, src, dest):
... ...
>>> my_copy_task = new CopyTask('~/Desktop/memes.jpg', '/kernel/')
>>> my_copy_task.Execute()
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import abc
import collections
import enum
from googlecloudsdk.core.util import debug_output
import six
class Topic(enum.Enum):
  """Categorizes different task messages."""
  # Member values are snake_case string forms of the names. Definition order
  # determines enum iteration order, so avoid reordering members.
  API_DOWNLOAD_RESULT = 'api_download_result'
  # Set exit code to 1.
  CHANGE_EXIT_CODE = 'change_exit_code'
  CRC32C = 'crc32c'
  CREATED_RESOURCE = 'created_resource'
  ERROR = 'error'
  # Set exit code to 1 and sends signal not to process new tasks
  # (for parallel execution).
  FATAL_ERROR = 'fatal_error'
  MD5 = 'md5'
  SET_IAM_POLICY = 'set_iam_policy'
  UPLOADED_COMPONENT = 'uploaded_component'
# Holds information to be passed between tasks.
#
# Attributes:
#   topic (Topic): The type of information this message holds.
#   payload (Any): The information itself.
Message = collections.namedtuple('Message', ('topic', 'payload'))
# Holds information returned from Task.Execute.
#
# Note that because information here is sent between processes, all data in
# this class must be picklable.
#
# Attributes:
#   additional_task_iterators (Optional[Iterable[Iterable[Task]]]): Tasks to
#     be executed such that all tasks in each Iterable[Task] are executed
#     before any tasks in the next Iterable[Task]. Tasks within each
#     Iterable[Task] are unordered. For example, if this value were the
#     following:
#
#     [
#       [UploadObjectTask(), UploadObjectTask(), UploadObjectTask()],
#       [ComposeObjectsTask()]
#     ]
#
#     All UploadObjectTasks should be completed before the ComposeObjectTask
#     could begin, but the UploadObjectTasks could be executed in parallel.
#   messages (Optional[Iterable[Message]]): Information to be passed to all
#     dependent tasks.
Output = collections.namedtuple(
    'Output', ('additional_task_iterators', 'messages')
)
class Task(six.with_metaclass(abc.ABCMeta, object)):
  """Abstract class to represent one command operation.

  Attributes:
    change_exit_code (bool): If True, failure of this task should update the
      exit_code to 1. Defaults to True.
    parallel_processing_key (Optional[Hashable]): Identifies a task during
      execution. If this value is not None, the executor will skip this task
      if another task being executed is using the same key. If this value is
      None, the executor will not skip any tasks based on it.
    received_messages (Iterable[Message]): Messages sent to this task by its
      dependencies.
  """

  def __init__(self):
    self.change_exit_code = True
    self.parallel_processing_key = None
    self.received_messages = []

  @abc.abstractmethod
  def execute(self, task_status_queue=None):
    """Performs some work based on class attributes.

    Args:
      task_status_queue (multiprocessing.Queue): Used by task to report its
        progress to a central location.

    Returns:
      An Output instance, or None.
    """

  def exit_handler(self, error=None, task_status_queue=None):
    """Task executor calls this method on a completed task before discarding it.

    An example use case is a subclass that needs to report its final status
    and if it failed or succeeded at some operation.

    Args:
      error (Exception|None): Task executor may pass an error object.
      task_status_queue (multiprocessing.Queue): Used by task to report its
        progress to a central location.
    """
    del error, task_status_queue  # Unused.

  def __repr__(self):
    return debug_output.generic_repr(self)

View File

@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements a buffer for tasks used in task_graph_executor.
See go/parallel-processing-in-gcloud-storage for more information.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import copy
from six.moves import queue
BUFFER_HEADER = 'Buffer Contents:\n'
BUFFER_EMPTY_MESSAGE = 'Task Buffer is empty.'
class _PriorityWrapper:
  """Pairs a buffered item with the priority used to schedule it.

  Attributes:
    task (Union[task.Task, str]): A buffered item. Expected to be a task or a
      string (to handle shutdowns) when used by task_graph_executor.
    priority (int): The priority of this task. A task with a lower value is
      executed before a task with a higher value, since queue.PriorityQueue
      uses a min-heap.
  """

  def __init__(self, task, priority):
    """Stores the buffered item and its scheduling priority."""
    self.priority = priority
    self.task = task

  def __lt__(self, other):
    """Orders wrappers by ascending priority for the underlying heap."""
    return self.priority < other.priority
class TaskBuffer:
  """Stores and prioritizes tasks.

  The current implementation uses a queue.PriorityQueue under the hood, since
  in experiments we found that the heap it maintains did not add too much
  overhead. If it does end up being a bottleneck, the same API can be
  implemented with a collections.deque.
  """

  def __init__(self):
    self._queue = queue.PriorityQueue()

  def get(self):
    """Removes and returns an item from the buffer.

    Calls to `get` block if there are no elements in the queue, and return
    prioritized items before non-prioritized items.

    Returns:
      A buffered item. Expected to be a task or a string (to handle shutdowns)
      when used by task_graph_executor.
    """
    return self._queue.get().task

  def put(self, task, prioritize=False):
    """Adds an item to the buffer.

    Args:
      task (Union[task.Task, str]): A buffered item. Expected to be a task or a
        string (to handle shutdowns) when used by task_graph_executor.
      prioritize (bool): Tasks added with prioritize=True will be returned by
        `get` before tasks added with prioritize=False.
    """
    priority = 0 if prioritize else 1
    self._queue.put(_PriorityWrapper(task, priority))

  def size(self) -> int:
    """Returns the number of items in the buffer."""
    return self._queue.qsize()

  def __str__(self):
    """Returns a string representation of the buffer."""
    if self.size() == 0:
      return BUFFER_EMPTY_MESSAGE
    # Take a shallow snapshot of the heap's backing list instead of
    # deep-copying it: iteration does not mutate the wrapped items, and
    # deep-copying arbitrary tasks is expensive (or unsupported for tasks
    # holding handles). Iterating the snapshot also avoids the previous
    # O(n^2) pop(0) loop.
    output_lines = [BUFFER_HEADER]
    output_lines.extend(
        str(priority_wrapper.task)
        for priority_wrapper in list(self._queue.queue))
    return '\n'.join(output_lines)

View File

@@ -0,0 +1,141 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Function for executing the tasks contained in a Task Iterator.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import optimize_parameters_util
from googlecloudsdk.command_lib.storage import plurality_checkable_iterator
from googlecloudsdk.command_lib.storage.tasks import task_graph_executor
from googlecloudsdk.command_lib.storage.tasks import task_status
from googlecloudsdk.command_lib.storage.tasks import task_util
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
def _execute_tasks_sequential(task_iterator,
                              received_messages=None,
                              task_status_queue=None,
                              continue_on_error=False):
  """Executes task objects sequentially.

  Args:
    task_iterator (Iterable[task.Task]): An iterator for task objects.
    received_messages (Iterable[task.Message]): Messages sent to each
      task in task_iterator.
    task_status_queue (multiprocessing.Queue|None): Used by task to report it
      progress to a central location.
    continue_on_error (bool): If True, execution will continue even if
      errors occur.

  Returns:
    Tuple[int, Iterable[task.Message]]: The first element in the tuple
    is the exit code and the second element is an iterable of messages
    emitted by the tasks in task_iterator.
  """
  exit_code = 0
  messages_from_current_task_iterator = []
  for task in task_iterator:
    if received_messages is not None:
      task.received_messages = received_messages
    task_execution_error = None
    try:
      task_output = task.execute(task_status_queue=task_status_queue)
    except Exception as e: # pylint: disable=broad-except
      task_execution_error = e
      # Fatal errors always abort; non-fatal ones are logged and skipped when
      # continue_on_error is set.
      if (
          not isinstance(task_execution_error, errors.FatalError)
          and continue_on_error
      ):
        log.error(str(e))
        if task.change_exit_code:
          exit_code = 1
        continue
      else:
        raise
    finally:
      # Runs on success, on `continue`, and on re-raise, so the task can
      # always report its final status.
      task.exit_handler(task_execution_error, task_status_queue)
    # Reached only on successful execution (the except branch either
    # `continue`s or re-raises), so task_output is always bound here.
    if task_output is None:
      continue
    if task_output.messages is not None:
      messages_from_current_task_iterator.extend(task_output.messages)
    if task_output.additional_task_iterators is not None:
      # Each returned iterator is a dependency layer: messages emitted by one
      # layer are forwarded to the next via messages_for_dependent_tasks.
      messages_for_dependent_tasks = []
      for additional_task_iterator in task_output.additional_task_iterators:
        exit_code_from_dependent_tasks, messages_for_dependent_tasks = (
            _execute_tasks_sequential(
                additional_task_iterator,
                messages_for_dependent_tasks,
                task_status_queue=task_status_queue,
                continue_on_error=continue_on_error))
        exit_code = max(exit_code_from_dependent_tasks, exit_code)
  return exit_code, messages_from_current_task_iterator
def execute_tasks(task_iterator,
                  parallelizable=False,
                  task_status_queue=None,
                  progress_manager_args=None,
                  continue_on_error=False):
  """Routes tasks to the parallel or the sequential executor.

  Args:
    task_iterator: An iterator for task objects.
    parallelizable (boolean): Should tasks be executed in parallel.
    task_status_queue (multiprocessing.Queue|None): Used by task to report its
      progress to a central location.
    progress_manager_args (task_status.ProgressManagerArgs|None):
      Determines what type of progress indicator to display.
    continue_on_error (bool): Only applicable for sequential mode. If True,
      execution will continue even if errors occur.

  Returns:
    An integer indicating the exit_code. Zero indicates no fatal errors were
    raised.
  """
  task_util.require_python_3_5()
  buffered_task_iterator = (
      plurality_checkable_iterator.PluralityCheckableIterator(task_iterator))
  optimize_parameters_util.detect_and_set_best_config(
      is_estimated_multi_file_workload=buffered_task_iterator.is_plural())
  # Some tasks assume they only ever run when parallelizable is True and
  # consult should_use_parallelism themselves, so both conditions must hold
  # before choosing the graph executor.
  should_run_in_parallel = (
      parallelizable and task_util.should_use_parallelism())
  if not should_run_in_parallel:
    with task_status.progress_manager(task_status_queue, progress_manager_args):
      exit_code, _ = _execute_tasks_sequential(
          buffered_task_iterator,
          task_status_queue=task_status_queue,
          continue_on_error=continue_on_error)
    return exit_code
  return task_graph_executor.TaskGraphExecutor(
      buffered_task_iterator,
      max_process_count=properties.VALUES.storage.process_count.GetInt(),
      thread_count=properties.VALUES.storage.thread_count.GetInt(),
      task_status_queue=task_status_queue,
      progress_manager_args=progress_manager_args).run()

View File

@@ -0,0 +1,372 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements logic for tracking task dependencies in task_graph_executor.
See go/parallel-processing-in-gcloud-storage for more information.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import threading
from typing import List
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.core import log
INITIAL_INDENT_LEVEL = 2
TASK_GRAPH_HEADER = 'Task Graph:'
TASK_WRAPPER_ID = ' - Task ID: {}\n'
TASK_DETAILS = (
' - Task: {}\n'
' - Dependency Count: {}\n'
' - Dependent Task IDs: {}\n'
' - Is Submitted: {}\n'
)
class TaskWrapper:
  """Embeds a Task instance in a dependency graph.

  Attributes:
    id (Hashable): A unique identifier for this task wrapper.
    task (googlecloudsdk.command_lib.storage.tasks.task.Task): An instance of a
      task class.
    dependency_count (int): The number of unexecuted dependencies this task
      has, i.e. this node's in-degree in a graph where an edge from A to B
      indicates that A must be executed before B.
    dependent_task_ids (Optional[Iterable[Hashable]]): The id of the tasks that
      require this task to be completed for their own completion. This value
      should be None if no tasks depend on this one.
    is_submitted (bool): True if this task has been submitted for execution.
  """

  def __init__(self, task_id, task, dependent_task_ids):
    """Initializes the wrapper with zero dependencies and not yet submitted."""
    self.is_submitted = False
    self.dependency_count = 0
    self.dependent_task_ids = dependent_task_ids
    self.task = task
    self.id = task_id

  def __str__(self):
    """Returns a string representation of the TaskWrapper."""
    if self.dependent_task_ids:
      dependent_count = len(self.dependent_task_ids)
    else:
      dependent_count = 0
    details = TASK_DETAILS.format(
        self.task.__class__.__name__,
        dependent_count,
        self.dependent_task_ids,
        self.is_submitted,
    )
    return TASK_WRAPPER_ID.format(self.id) + details
class InvalidDependencyError(errors.Error):
  """Raised on attempts to create an invalid dependency.

  Invalid dependencies are self-dependencies and those that involve nodes that
  do not exist.
  """
class TaskGraph:
  """Tracks dependencies between Task instances.

  See googlecloudsdk.command_lib.storage.tasks.task.Task for the definition of
  the Task class.

  The public methods in this class are thread safe.

  Attributes:
    is_empty (threading.Event): is_empty.is_set() is True when the graph has no
      tasks in it.
  """

  def __init__(self, top_level_task_limit):
    """Initializes a TaskGraph instance.

    Args:
      top_level_task_limit (int): A top-level task is a task that no other tasks
        depend on for completion (i.e. dependent_task_ids is None). Adding
        top-level tasks with TaskGraph.add will block until there are fewer than
        this number of top-level tasks in the graph.
    """
    self.is_empty = threading.Event()
    self.is_empty.set()
    # Used to synchronize graph updates. This needs to be an RLock since this
    # lock is acquired by each recursive call to TaskGraph.complete.
    self._lock = threading.RLock()
    # A dict[int, TaskWrapper]. Maps ids to task wrapper instances for tasks
    # currently in the graph.
    self._task_wrappers_in_graph = {}
    # Acquired whenever a top-level task is added to the graph, and released
    # when a top-level task is completed. This helps keep memory usage under
    # control by limiting the graph size.
    self._top_level_task_semaphore = threading.Semaphore(top_level_task_limit)

  def add(self, task, dependent_task_ids=None):
    """Adds a task to the graph.

    Args:
      task (googlecloudsdk.command_lib.storage.tasks.task.Task): The task to be
        added.
      dependent_task_ids (Optional[List[Hashable]]): TaskWrapper.id attributes
        for tasks already in the graph that require the task being added to
        complete before being executed. This argument should be None for
        top-level tasks, which no other tasks depend on.

    Returns:
      A TaskWrapper instance for the task passed into this function, or None if
      task.parallel_processing_key was the same as another task's
      parallel_processing_key.

    Raises:
      InvalidDependencyError if any id in dependent_task_ids is not in the
      graph, or if a the add operation would have created a self-dependency.
    """
    is_top_level_task = dependent_task_ids is None
    if is_top_level_task:
      # Blocks until fewer than top_level_task_limit top-level tasks are in
      # the graph, bounding memory usage.
      self._top_level_task_semaphore.acquire()
    with self._lock:
      # parallel_processing_key doubles as a deduplication key; fall back to
      # object identity when no key is set.
      if task.parallel_processing_key is not None:
        identifier = task.parallel_processing_key
      else:
        identifier = id(task)
      if identifier in self._task_wrappers_in_graph:
        if task.parallel_processing_key is not None:
          log.status.Print(
              'Skipping {} for {}. This can occur if a cp command results in '
              'multiple writes to the same resource.'.format(
                  task.__class__.__name__, task.parallel_processing_key))
        else:
          log.status.Print(
              'Skipping {}. This is probably because due to a bug that '
              'caused it to be submitted for execution more than once.'.format(
                  task.__class__.__name__))
        if is_top_level_task:
          # Undo the acquisition above, since the skipped task will never be
          # completed (and would otherwise leak a semaphore slot).
          self._top_level_task_semaphore.release()
        return
      task_wrapper = TaskWrapper(identifier, task, dependent_task_ids)
      for task_id in dependent_task_ids or []:
        try:
          self._task_wrappers_in_graph[task_id].dependency_count += 1
        except KeyError:
          raise InvalidDependencyError
      self._task_wrappers_in_graph[task_wrapper.id] = task_wrapper
      self.is_empty.clear()
      return task_wrapper

  def complete(self, task_wrapper):
    """Recursively removes a task and its parents from the graph if possible.

    Tasks can be removed only if they have been submitted for execution and have
    no dependencies. Removing a task can affect dependent tasks in one of two
    ways, if the removal left the dependent tasks with no dependencies:
    - If the dependent task has already been submitted, it can also be removed.
    - If the dependent task has not already been submitted, it can be
      submitted for execution.

    This method removes all tasks that removing task_wrapper allows, and returns
    all tasks that can be submitted after removing task_wrapper.

    Args:
      task_wrapper (TaskWrapper): The task_wrapper instance to remove.

    Returns:
      An Iterable[TaskWrapper] that yields tasks that are submittable after
      completing task_wrapper.
    """
    with self._lock:
      if task_wrapper.dependency_count:
        # This task has dependencies, so it cannot be removed from the graph and
        # cannot be submitted for execution.
        return []
      if not task_wrapper.is_submitted:
        # This task does not have dependencies and has not already been
        # submitted, so it can now be executed.
        return [task_wrapper]
      # At this point, this task does not have dependencies and has already
      # been submitted for execution. This means we can remove it from the
      # graph.
      del self._task_wrappers_in_graph[task_wrapper.id]
      if task_wrapper.dependent_task_ids is None:
        # We've completed a top-level task, so we should allow more to be added.
        self._top_level_task_semaphore.release()
        if not self._task_wrappers_in_graph:
          self.is_empty.set()
        return []
      # After removing this task, some dependent tasks may now be executable.
      # We can check this by decrementing all of their dependency counts and
      # recursively calling this function.
      submittable_tasks = []
      for task_id in task_wrapper.dependent_task_ids:
        dependent_task_wrapper = self._task_wrappers_in_graph[task_id]
        dependent_task_wrapper.dependency_count -= 1
        # Aggregates all of the submittable tasks found by recursive calls.
        submittable_tasks += self.complete(dependent_task_wrapper)
      return submittable_tasks

  def update_from_executed_task(self, executed_task_wrapper, task_output):
    r"""Updates the graph based on the output of an executed task.

    If some googlecloudsdk.command_lib.storage.task.Task instance `a` returns
    the following iterables of tasks: [[b, c], [d, e]], we need to update the
    graph as follows to ensure they are executed appropriately.

           /-- d <-\--/- b
      a <-/         \/
          \         /\
           \-- e <-/--\- c

    After making these updates, `b` and `c` are ready for submission. If a task
    does not return any new tasks, then it will be removed from the graph,
    potentially freeing up tasks that depend on it for execution.

    See go/parallel-processing-in-gcloud-storage#heading=h.y4o7a9hcs89r for a
    more thorough description of the updates this method performs.

    Args:
      executed_task_wrapper (task_graph.TaskWrapper): Contains information about
        how a completed task fits into a dependency graph.
      task_output (Optional[task.Output]): Additional tasks and
        messages returned by the task in executed_task_wrapper.

    Returns:
      An Iterable[task_graph.TaskWrapper] containing tasks that are ready to be
      executed after performing graph updates.
    """
    with self._lock:
      # Forward this task's messages to its dependents before any graph
      # mutation, so they are visible when the dependents execute.
      if (task_output is not None
          and task_output.messages is not None
          and executed_task_wrapper.dependent_task_ids is not None):
        for task_id in executed_task_wrapper.dependent_task_ids:
          dependent_task_wrapper = self._task_wrappers_in_graph[task_id]
          dependent_task_wrapper.task.received_messages.extend(
              task_output.messages)
      if task_output is None or not task_output.additional_task_iterators:
        # The executed task did not return new tasks, so the only ones newly
        # ready for execution will be those freed up after removing the executed
        # task.
        return self.complete(executed_task_wrapper)
      parent_tasks_for_next_layer = [executed_task_wrapper]
      # Tasks return additional tasks in the order they should be executed in,
      # but adding them to the graph is more easily done in reverse.
      for task_iterator in reversed(task_output.additional_task_iterators):
        dependent_task_ids = [
            task_wrapper.id for task_wrapper in parent_tasks_for_next_layer
        ]
        parent_tasks_for_next_layer = []
        for task in task_iterator:
          task_wrapper = self.add(task, dependent_task_ids=dependent_task_ids)
          if task_wrapper is not None:
            parent_tasks_for_next_layer.append(task_wrapper)
      # If the dependent tasks are skipped, then the parent tasks needs to be
      # marked as completed
      if not parent_tasks_for_next_layer:
        self.complete(executed_task_wrapper)
      return parent_tasks_for_next_layer

  def __str__(self):
    """Returns a string representation of the TaskGraph."""
    output: List[str] = [
        TASK_GRAPH_HEADER,
        f' - Empty: {self.is_empty.is_set()}',
        f' - Task Wrappers: {len(self._task_wrappers_in_graph)}',
    ]
    if self._task_wrappers_in_graph:
      printed_tasks = set()
      output.extend(
          self._print_task_wrapper_recursive(
              self._task_wrappers_in_graph.values(),
              INITIAL_INDENT_LEVEL,
              printed_tasks,
          )
      )
    else:
      output.append('No tasks in the graph to print.')
    return '\n'.join(output)

  def _print_task_wrapper_recursive(
      self, task_wrappers, indent_level, printed_tasks
  ) -> List[str]:
    """Recursively yields task wrappers and their dependencies.

    Example:
      Suppose we have task wrappers representing tasks with dependencies:
      task_wrapper1 = TaskWrapper(id='task1',
      dependent_task_ids=['task2', 'task3']),
      task_wrapper2 = TaskWrapper(id='task2', dependent_task_ids=['task4'])
      task_wrapper3 = TaskWrapper(id='task3', dependent_task_ids=[])
      task_wrapper4 = TaskWrapper(id='task4', dependent_task_ids=[])
      task_wrappers = [task_wrapper1, task_wrapper2,
      task_wrapper3, task_wrapper4]
      Calling _print_task_wrapper_recursive(task_wrappers, 0, set())
      would produce:
      ['task1',
      '  task2',
      '    task4',
      '  task3']
      This shows the tasks and their dependencies formatted with appropriate
      indentation levels.

    Args:
      task_wrappers (list): List of task wrappers to print.
      indent_level (int): Current level of indentation for formatting.
      printed_tasks (set): Set of task IDs that have already been printed.

    Yields:
      List of formatted strings representing the task wrappers
      and their dependencies.
    """
    # NOTE(review): indent_level is threaded through the recursion but never
    # applied to the yielded strings, so the output is flat even though the
    # docstring example shows indentation — confirm whether indentation was
    # intended here.
    for task_wrapper in task_wrappers:
      if task_wrapper.id not in printed_tasks:
        # printed_tasks guards against re-printing shared dependencies and
        # against infinite recursion on cycles.
        printed_tasks.add(task_wrapper.id)
        yield str(task_wrapper)
        if task_wrapper.dependent_task_ids:
          dependent_task_wrappers = [
              self._task_wrappers_in_graph[task_id]
              for task_id in task_wrapper.dependent_task_ids
          ]
          yield from self._print_task_wrapper_recursive(
              dependent_task_wrappers, indent_level + 2, printed_tasks)

View File

@@ -0,0 +1,245 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for debugging task graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import re
import sys
import threading
import traceback
from typing import Dict, Iterator
from googlecloudsdk.command_lib.storage.tasks import task_buffer
from googlecloudsdk.command_lib.storage.tasks import task_graph as task_graph_module
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import files
def is_task_graph_debugging_enabled() -> bool:
  """Returns True if the task graph debugging property is enabled."""
  debugging_property = properties.VALUES.storage.enable_task_graph_debugging
  return debugging_property.GetBool()
def get_time_interval_between_snapshots() -> int:
  """Returns the delay in seconds between two consecutive debug snapshots."""
  snapshot_duration_property = (
      properties.VALUES.storage.task_graph_debugging_snapshot_duration
  )
  return snapshot_duration_property.GetInt()
def yield_stack_traces() -> Iterator[str]:
  """Retrieve stack traces for all the threads in the current process.

  Yields:
    A header line per thread followed by one "File ..." line per frame and,
    when available, the stripped source text of that frame.
  """
  # pylint:disable=protected-access
  # There does not appear to be another way to collect the stack traces
  # for all running threads.
  for thread_id, stack in sys._current_frames().items():
    yield f'\n# Traceback for thread: {thread_id}'
    for filename, line_number, name, text in traceback.extract_stack(stack):
      # Report the frame's real source file (the previous version extracted
      # `filename` but emitted a hard-coded "(unknown)" placeholder, making
      # traces unactionable).
      yield f'File: "{filename}", line {line_number}, in {name}'
      if text:
        yield f'  {text.strip()}'
def _yield_management_thread_stack_traces(
    name_to_thread: Dict[str, threading.Thread],
    alive_thread_id_to_name: Dict[int, str],
) -> Iterator[str]:
  """Yields the stack traces of the alive management threads.

  Args:
    name_to_thread: Maps management thread names to their thread objects.
    alive_thread_id_to_name: Output parameter; populated here with the ident
      of each thread in name_to_thread that is alive when this generator runs.

  Yields:
    Stack-trace lines for the alive management threads, followed by a notice
    line for every management thread that is not running.
  """
  # Record which management threads are alive, keyed by ident, so their
  # traces can be selected from the all-threads dump below.
  for thread_name, thread in name_to_thread.items():
    if thread.is_alive():
      alive_thread_id_to_name[thread.ident] = thread_name
  all_threads_stack_traces = yield_stack_traces()
  current_thread_id = None
  thread_id_pattern = re.compile(r'^\n# Traceback for thread:(.*)')
  for line in all_threads_stack_traces:
    # Header lines switch the "current thread" context; body lines belong to
    # the most recently seen header.
    if thread_id_match := thread_id_pattern.match(line):
      current_thread_id = int(thread_id_match.group(1))
    if (
        current_thread_id in alive_thread_id_to_name
    ): # printing the stack traces of only the alive management threads.
      if thread_id_match:
        # NOTE(review): this yields a name-based header and then also yields
        # the original id-based header via `yield line` below — confirm the
        # duplicated header is intended.
        yield (
            '\n# Traceback for'
            f' thread:{alive_thread_id_to_name[current_thread_id]}'
        )
      yield line
  for thread_name, thread in name_to_thread.items():
    if thread.ident not in alive_thread_id_to_name:
      yield (
          f'\n# Thread {thread_name} is not running. Cannot get stack trace at'
          ' the moment.'
      )
def print_management_thread_stacks(
    management_threads_name_to_function: Dict[str, threading.Thread],
):
  """Prints stack traces of the management threads."""
  log.status.Print(
      'Initiating stack trace information of the management threads.'
  )
  alive_thread_id_to_name = {}
  # Stream the generator straight into the log rather than materializing it.
  for trace_line in _yield_management_thread_stack_traces(
      management_threads_name_to_function, alive_thread_id_to_name
  ):
    log.status.Print(trace_line)
def print_worker_thread_stack_traces(stack_trace_file_path):
  """Prints stack traces of the worker threads."""
  try:
    stack_traces = files.ReadFileContents(stack_trace_file_path)
  except IOError as e:
    log.error(f'Error reading stack trace file: {e}')
    log.status.Print('No stack traces could be retrieved.')
    return
  # Guard clause: an empty file means no worker threads were running.
  if not stack_traces:
    log.status.Print('No stack traces found. No worker threads running.')
    return
  log.status.Print('Printing stack traces for worker threads:')
  for trace_line in stack_traces.splitlines():
    log.status.Print(trace_line.strip())
def print_queue_size(task_queue, task_status_queue, task_output_queue):
  """Prints the size of the queues."""
  # Iterate over (label, queue) pairs so the three reports stay uniform.
  for queue_label, queue_object in (
      ('Task Queue', task_queue),
      ('Task Status Queue', task_status_queue),
      ('Task Output Queue', task_output_queue),
  ):
    log.status.Print(f'{queue_label} size: {queue_object.qsize()}')
def _is_task_graph_empty(task_graph: task_graph_module.TaskGraph) -> bool:
  """Returns True if no tasks remain in the task graph."""
  return task_graph.is_empty.is_set()
def _is_task_buffer_empty(task__buffer: task_buffer.TaskBuffer) -> bool:
  """Returns True if the task buffer currently holds no items."""
  return not task__buffer.size()
def task_graph_debugger_worker(
    management_threads_name_to_function: Dict[str, threading.Thread],
    stack_trace_file: str,
    task_graph: task_graph_module.TaskGraph,
    task__buffer: task_buffer.TaskBuffer,
    delay_seconds: int,
):
  """The main worker function for the task graph debugging framework.

  Prints the stack traces of the management threads involved namely
  iterator_to_buffer, buffer_to_queue and task_output_handler. Captures and
  prints the contents of the task graph and task buffer.
  Also prints the stack traces of the worker threads if they are running at the
  particular snapshot taken.

  Args:
    management_threads_name_to_function: A dictionary of management thread name
      to the thread function.
    stack_trace_file: Path to the file containing the stack traces of the worker
      threads.
    task_graph: The task graph object.
    task__buffer: The task buffer object.
    delay_seconds: The time interval between two consecutive snapshots.
  """
  is_task_graph_empty = _is_task_graph_empty(task_graph)
  is_task_buffer_empty = _is_task_buffer_empty(task__buffer)
  # Set it to true to ensure that the debugger worker prints the status
  # atleast once.
  is_some_management_thread_alive = True
  # A single reusable event used purely as an interruptible sleep; the
  # previous version allocated a fresh threading.Event on every iteration.
  sleep_event = threading.Event()
  while (
      is_some_management_thread_alive
      or not is_task_graph_empty
      or not is_task_buffer_empty
  ):
    print_management_thread_stacks(management_threads_name_to_function)
    print_worker_thread_stack_traces(stack_trace_file)
    log.status.Print(str(task_graph))
    log.status.Print(str(task__buffer))
    is_task_graph_empty = _is_task_graph_empty(task_graph)
    is_task_buffer_empty = _is_task_buffer_empty(task__buffer)
    is_some_management_thread_alive = any(
        thread.is_alive()
        for thread in management_threads_name_to_function.values()
    )
    # Wait for the delay_seconds to pass before taking the next snapshot
    # if conditions are met.
    sleep_event.wait(delay_seconds)
def start_thread_for_task_graph_debugging(
    management_threads_name_to_function: Dict[str, threading.Thread],
    stack_trace_file: str,
    task_graph: task_graph_module.TaskGraph,
    task__buffer: task_buffer.TaskBuffer,
):
  """Starts a thread for task graph debugging."""
  try:
    # Construction and start both stay inside the try so any failure is
    # logged rather than propagated.
    debugging_thread = threading.Thread(
        target=task_graph_debugger_worker,
        args=(
            management_threads_name_to_function,
            stack_trace_file,
            task_graph,
            task__buffer,
            get_time_interval_between_snapshots(),
        ),
    )
    debugging_thread.start()
  except Exception as e:  # pylint: disable=broad-except
    log.error(f'Error starting thread: {e}')
def write_stack_traces_to_file(
    stack_traces: Iterator[str], stack_trace_file_path: str
):
  """Writes stack traces to a file.

  Args:
    stack_traces: Lines of stack-trace text to persist.
    stack_trace_file_path: Destination path; a falsy value disables writing.
  """
  if not stack_trace_file_path:
    return
  try:
    # Single-pass comprehension: strip each entry once and drop empties
    # (the previous loop called strip() twice per entry and appended
    # manually).
    content = '\n'.join(
        stripped_entry
        for entry in stack_traces
        if (stripped_entry := entry.strip())
    )
    files.WriteFileContents(stack_trace_file_path, content)
  except Exception as e:  # pylint: disable=broad-except
    log.error(f'An error occurred while writing stack trace file: {e}')

View File

@@ -0,0 +1,734 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements parallel task execution for the storage surface.
See go/parallel-processing-in-gcloud-storage for more information.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import contextlib
import functools
import multiprocessing
import signal as signal_lib
import sys
import tempfile
import threading
from googlecloudsdk.api_lib.storage.gcs_json import patch_apitools_messages
from googlecloudsdk.command_lib import crash_handling
from googlecloudsdk.command_lib.storage import encryption_util
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage.tasks import task
from googlecloudsdk.command_lib.storage.tasks import task_buffer
from googlecloudsdk.command_lib.storage.tasks import task_graph as task_graph_module
from googlecloudsdk.command_lib.storage.tasks import task_graph_debugger
from googlecloudsdk.command_lib.storage.tasks import task_status
from googlecloudsdk.core import execution_utils
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import transport
from googlecloudsdk.core.console import console_io
from googlecloudsdk.core.credentials import creds_context_managers
from googlecloudsdk.core.util import platforms
from six.moves import queue
# TODO(b/171296237): Remove this when fixes are submitted in apitools.
patch_apitools_messages.patch()
if sys.version_info.major == 2:
  # multiprocessing.get_context is only available in Python 3. We don't support
  # Python 2, but some of our code still runs at import in Python 2 tests, so
  # we need to provide a value here.
  multiprocessing_context = multiprocessing
else:
  # An explicit property setting always wins over the platform heuristics
  # below.
  _method = properties.VALUES.storage.multiprocessing_default_method.Get()
  if _method is not None:
    multiprocessing_context = multiprocessing.get_context(method=_method)
  else:
    _should_force_spawn = (
        # On MacOS, fork is unsafe: https://bugs.python.org/issue33725. The
        # default start method is spawn on versions >= 3.8, but we need to set
        # it explicitly for older versions.
        platforms.OperatingSystem.Current() is platforms.OperatingSystem.MACOSX
        or
        # On Linux, fork causes issues when mTLS is enabled:
        # go/ecp-gcloud-storage
        # The default start method on Linux is fork, hence we will set it to
        # spawn when client certificate authentication (mTLS) is enabled.
        (
            properties.VALUES.context_aware.use_client_certificate.GetBool()
            and platforms.OperatingSystem.Current()
            is platforms.OperatingSystem.LINUX
        )
    )
    if _should_force_spawn:
      multiprocessing_context = multiprocessing.get_context(method='spawn')
    # TODO(b/438968865): Re-evaluate this workaround once the root cause of the
    # forkserver-related test failures in Python 3.14 is understood and
    # addressed.
    elif (sys.version_info.major == 3 and sys.version_info.minor >= 14) and (
        platforms.OperatingSystem.Current() is platforms.OperatingSystem.LINUX
    ):
      # Force 'fork' start method for Linux.
      multiprocessing_context = multiprocessing.get_context(method='fork')
    else:
      # NOTE(review): get_context() with no method argument returns the
      # platform's DEFAULT context ('fork' on Linux before Python 3.14,
      # 'spawn' on macOS and Windows) — it does not force 'fork'
      # unconditionally as the previous comment claimed. Confirm a default
      # context is the intended fallback here.
      multiprocessing_context = multiprocessing.get_context()
_TASK_QUEUE_LOCK = threading.Lock()
# TODO(b/203819260): Check if this lock can be removed on Windows, since message
# patches are applied above.
@contextlib.contextmanager
def _task_queue_lock():
  """Context manager which acquires a lock when queue.get is unsafe.

  On Python 3.5 with spawn enabled, a race condition affects unpickling
  objects in queue.get calls. This manifests as an AttributeError
  intermittently thrown by ForkingPickler.loads, e.g.:

  AttributeError: Can't get attribute 'FileDownloadTask' on <module
  'googlecloudsdk.command_lib.storage.tasks.cp.file_download_task' from
  'googlecloudsdk/command_lib/storage/tasks/cp/file_download_task.py'

  Adding a lock around queue.get calls using this context manager resolves the
  issue.

  Yields:
    None, but acquires a lock which is released on exit.
  """
  unpickling_is_racy = (
      sys.version_info.major == 3 and sys.version_info.minor <= 5
      and multiprocessing_context.get_start_method() == 'spawn'
  )
  if unpickling_is_racy:
    # `with` releases the lock on both normal exit and exception, matching
    # the acquire/try/finally/release pattern this replaces.
    with _TASK_QUEUE_LOCK:
      yield
  else:
    yield
# When threads get this value, they should prepare to exit.
#
# Threads should check for this value with `==` and not `is`, since the pickling
# carried out by multiprocessing.Queue may cause `is` to incorrectly return
# False.
#
# When the executor is shutting down, this value is added to
# TaskGraphExecutor._executable_tasks and is passed to
# TaskGraphExecutor._task_queue.
_SHUTDOWN = 'SHUTDOWN'

# Sent on the signal queue to ask the worker-process factory
# (_process_factory) to spawn one additional worker process.
_CREATE_WORKER_PROCESS = 'CREATE_WORKER_PROCESS'
class _DebugSignalHandler:
  """Signal handler for collecting debug information."""

  def __init__(self):
    """Initializes the debug signal handler.

    Only registers a signal number on non-Windows systems, since SIGUSR1
    does not exist on Windows.
    """
    current_os = platforms.OperatingSystem.Current()
    if current_os is not platforms.OperatingSystem.WINDOWS:
      self._debug_signal = signal_lib.SIGUSR1

  def _debug_handler(
      self, signal_number: int = None, frame: object = None
  ) -> None:
    """Logs stack traces of running threads.

    Args:
      signal_number: Signal number.
      frame: Frame object.
    """
    del signal_number, frame  # currently unused
    log.debug('Initiating crash debug information data collection.')
    for trace_line in task_graph_debugger.yield_stack_traces():
      log.debug(trace_line)

  def install(self):
    """Installs the debug signal handler."""
    if platforms.OperatingSystem.Current() is platforms.OperatingSystem.WINDOWS:
      # Not supported for windows systems.
      return
    try:
      # Remember the previous handler so terminate() can restore it.
      self._original_signal_handler = signal_lib.getsignal(self._debug_signal)
      signal_lib.signal(self._debug_signal, self._debug_handler)
    except ValueError:
      # Signal handlers can be registered from the main thread only.
      pass

  def terminate(self):
    """Restores the original signal handler.

    This method should be called when the debug signal handler is no longer
    needed.
    """
    if platforms.OperatingSystem.Current() is platforms.OperatingSystem.WINDOWS:
      # Not supported for windows systems.
      return
    try:
      if hasattr(self, '_original_signal_handler'):
        signal_lib.signal(self._debug_signal, self._original_signal_handler)
    except ValueError:
      # Signal handlers can be restored from the main thread only.
      pass
class SharedProcessContext:
  """Context manager used to collect and set global state."""

  @staticmethod
  def _child_inherits_state():
    # With the 'fork' start method, children share the parent's memory, so
    # there is no global state to capture or replicate.
    return multiprocessing_context.get_start_method() == 'fork'

  def __init__(self):
    """Collects global state in the main process."""
    if self._child_inherits_state():
      return
    self._environment_variables = execution_utils.GetToolEnv()
    self._creds_context_manager = (
        creds_context_managers.CredentialProvidersManager())
    self._key_store = encryption_util._key_store
    self._invocation_id = transport.INVOCATION_ID

  def __enter__(self):
    """Sets global state in child processes."""
    if self._child_inherits_state():
      return
    self._environment_context_manager = execution_utils.ReplaceEnv(
        **self._environment_variables)
    self._environment_context_manager.__enter__()
    self._creds_context_manager.__enter__()
    encryption_util._key_store = self._key_store
    transport.INVOCATION_ID = self._invocation_id
    # Passing None causes log settings to be refreshed based on property
    # values.
    log.SetUserOutputEnabled(None)
    log.SetVerbosity(None)

  def __exit__(self, exc_type, exc_value, exc_traceback):
    """Cleans up global state in child processes."""
    if self._child_inherits_state():
      return
    self._environment_context_manager.__exit__(
        exc_type, exc_value, exc_traceback)
    self._creds_context_manager.__exit__(exc_type, exc_value, exc_traceback)
@crash_handling.CrashManager
def _thread_worker(task_queue, task_output_queue, task_status_queue,
                   idle_thread_count):
  """A consumer thread run in a child process.

  Args:
    task_queue (multiprocessing.Queue): Holds task_graph.TaskWrapper instances.
    task_output_queue (multiprocessing.Queue): Sends information about completed
      tasks back to the main process.
    task_status_queue (multiprocessing.Queue|None): Used by task to report its
      progress to a central location.
    idle_thread_count (multiprocessing.Semaphore): Keeps track of how many
      threads are busy. Useful for spawning new workers if all threads are busy.
  """

  def _output_for_failure(failed_wrapper, error):
    # Maps an execution error to the task.Output the main process expects.
    if isinstance(error, errors.FatalError):
      return task.Output(
          additional_task_iterators=None,
          messages=[task.Message(topic=task.Topic.FATAL_ERROR, payload={})])
    if failed_wrapper.task.change_exit_code:
      return task.Output(
          additional_task_iterators=None,
          messages=[
              task.Message(topic=task.Topic.CHANGE_EXIT_CODE, payload={})
          ])
    return None

  while True:
    with _task_queue_lock():
      wrapper = task_queue.get()
    if wrapper == _SHUTDOWN:
      break
    idle_thread_count.acquire()

    execution_error = None
    try:
      output = wrapper.task.execute(task_status_queue=task_status_queue)
    # pylint: disable=broad-except
    # If any exception is raised, it will prevent the executor from exiting.
    except Exception as caught_error:
      execution_error = caught_error
      log.error(caught_error)
      log.debug(caught_error, exc_info=sys.exc_info())
      output = _output_for_failure(wrapper, caught_error)
    # pylint: enable=broad-except
    finally:
      wrapper.task.exit_handler(execution_error, task_status_queue)

    task_output_queue.put((wrapper, output))
    idle_thread_count.release()
@crash_handling.CrashManager
def _process_worker(
    task_queue,
    task_output_queue,
    task_status_queue,
    thread_count,
    idle_thread_count,
    shared_process_context,
    stack_trace_file_path
):
  """Starts a consumer thread pool.

  Args:
    task_queue (multiprocessing.Queue): Holds task_graph.TaskWrapper instances.
    task_output_queue (multiprocessing.Queue): Sends information about completed
      tasks back to the main process.
    task_status_queue (multiprocessing.Queue|None): Used by task to report its
      progress to a central location.
    thread_count (int): Number of threads the process should spawn.
    idle_thread_count (multiprocessing.Semaphore): Passed on to worker threads.
    shared_process_context (SharedProcessContext): Holds values from global
      state that need to be replicated in child processes.
    stack_trace_file_path (str): File path to write stack traces to.
  """
  with shared_process_context:
    workers = []
    for _ in range(thread_count):
      worker = threading.Thread(
          target=_thread_worker,
          args=(
              task_queue,
              task_output_queue,
              task_status_queue,
              idle_thread_count,
          ),
      )
      worker.start()
      workers.append(worker)

    # TODO: b/354829547 - Update the function to catch the updated stack traces
    # of the already running worker threads while a new worker process
    # is not created.
    if task_graph_debugger.is_task_graph_debugging_enabled():
      task_graph_debugger.write_stack_traces_to_file(
          task_graph_debugger.yield_stack_traces(), stack_trace_file_path
      )

    for worker in workers:
      worker.join()
@crash_handling.CrashManager
def _process_factory(
    task_queue,
    task_output_queue,
    task_status_queue,
    thread_count,
    idle_thread_count,
    signal_queue,
    shared_process_context,
    stack_trace_file_path
):
  """Create worker processes.

  This factory must run in a separate process to avoid a deadlock issue,
  see go/gcloud-storage-deadlock-issue/. Although we are adding one
  extra process by doing this, it will remain idle once all the child worker
  processes are created. Thus, it does not add a noticeable burden on the
  system.

  Args:
    task_queue (multiprocessing.Queue): Holds task_graph.TaskWrapper instances.
    task_output_queue (multiprocessing.Queue): Sends information about completed
      tasks back to the main process.
    task_status_queue (multiprocessing.Queue|None): Used by task to report its
      progress to a central location.
    thread_count (int): Number of threads the process should spawn.
    idle_thread_count (multiprocessing.Semaphore): Passed on to worker threads.
    signal_queue (multiprocessing.Queue): Queue used by parent process to
      signal when a new child worker process must be created.
    shared_process_context (SharedProcessContext): Holds values from global
      state that need to be replicated in child processes.
    stack_trace_file_path (str): File path to write stack traces to.
  """
  processes = []
  while True:
    # One signal message arrives for each process to be created.
    directive = signal_queue.get()

    if directive == _SHUTDOWN:
      # Every worker thread in every child process consumes exactly one
      # shutdown marker from the task queue before exiting.
      for _ in range(len(processes) * thread_count):
        task_queue.put(_SHUTDOWN)
      break

    if directive != _CREATE_WORKER_PROCESS:
      raise errors.Error('Received invalid signal for worker '
                         'process creation: {}'.format(directive))

    # Pre-credit the semaphore so the new process's threads count as idle.
    for _ in range(thread_count):
      idle_thread_count.release()

    new_process = multiprocessing_context.Process(
        target=_process_worker,
        args=(
            task_queue,
            task_output_queue,
            task_status_queue,
            thread_count,
            idle_thread_count,
            shared_process_context,
            stack_trace_file_path,
        ),
    )
    processes.append(new_process)
    log.debug('Adding 1 process with {} threads.'
              ' Total processes: {}. Total threads: {}.'.format(
                  thread_count, len(processes),
                  len(processes) * thread_count))
    new_process.start()

  for started_process in processes:
    started_process.join()
def _store_exception(target_function):
"""Decorator for storing exceptions raised from the thread targets.
Args:
target_function (function): Thread target to decorate.
Returns:
Decorator function.
"""
@functools.wraps(target_function)
def wrapper(self, *args, **kwargs):
try:
target_function(self, *args, **kwargs)
# pylint:disable=broad-except
except Exception as e:
# pylint:enable=broad-except
if not isinstance(self, TaskGraphExecutor):
# Storing of exception is only allowed for TaskGraphExecutor.
raise
with self.thread_exception_lock:
if self.thread_exception is None:
log.debug('Storing error to raise later: %s', e)
self.thread_exception = e
else:
# This indicates that the exception has been already stored for
# another thread. We will simply log the traceback in this
# case, since raising the error is not going to be handled by the
# main thread anyway.
log.error(e)
log.debug(e, exc_info=sys.exc_info())
return wrapper
class TaskGraphExecutor:
  """Executes an iterable of command_lib.storage.tasks.task.Task instances."""

  def __init__(
      self,
      task_iterator,
      max_process_count=multiprocessing.cpu_count(),  # Evaluated once, at import time.
      thread_count=4,
      task_status_queue=None,
      progress_manager_args=None,
  ):
    """Initializes a TaskGraphExecutor instance.

    No threads or processes are started by the constructor.

    Args:
      task_iterator (Iterable[command_lib.storage.tasks.task.Task]): Task
        instances to execute.
      max_process_count (int): The number of processes to start.
      thread_count (int): The number of threads to start per process.
      task_status_queue (multiprocessing.Queue|None): Used by task to report its
        progress to a central location.
      progress_manager_args (task_status.ProgressManagerArgs|None):
        Determines what type of progress indicator to display.
    """
    self._task_iterator = iter(task_iterator)
    self._max_process_count = max_process_count
    self._thread_count = thread_count
    self._task_status_queue = task_status_queue
    self._progress_manager_args = progress_manager_args

    # Number of worker processes requested so far (see _add_worker_process).
    self._process_count = 0
    # Counts threads currently waiting for work across all worker processes.
    self._idle_thread_count = multiprocessing_context.Semaphore(value=0)

    # Maximum number of worker threads across all worker processes.
    self._worker_count = self._max_process_count * self._thread_count

    # Sends task_graph.TaskWrapper instances to child processes.
    # Size must be 1. go/lazy-process-spawning-addendum.
    self._task_queue = multiprocessing_context.Queue(maxsize=1)

    # Sends information about completed tasks to the main process.
    self._task_output_queue = multiprocessing_context.Queue(
        maxsize=self._worker_count)

    # Queue for informing worker_process_creator to create a new process.
    self._signal_queue = multiprocessing_context.Queue(
        maxsize=self._worker_count + 1)

    # Tracks dependencies between tasks in the executor to help ensure that
    # tasks returned by executed tasks are completed in the correct order.
    self._task_graph = task_graph_module.TaskGraph(
        top_level_task_limit=2 * self._worker_count)

    # Holds tasks without any dependencies.
    self._executable_tasks = task_buffer.TaskBuffer()

    # For storing exceptions raised by the management threads (see the
    # _store_exception decorator).
    self.thread_exception = None
    self.thread_exception_lock = threading.Lock()

    # Cleared when a FATAL_ERROR message is observed (see _handle_task_output).
    self._accepting_new_tasks = True
    self._exit_code = 0
    self._debug_handler = _DebugSignalHandler()

    # Only populated when task-graph debugging is enabled; worker processes
    # and the debugging thread write thread stack traces to this file.
    self.stack_trace_file_path = None
    if task_graph_debugger.is_task_graph_debugging_enabled():
      try:
        with tempfile.NamedTemporaryFile(
            prefix='stack_trace', suffix='.txt', delete=False
        ) as f:
          self.stack_trace_file_path = f.name
      except IOError as e:
        log.error('Error creating stack trace file: %s', e)

    # Maps management-thread names to Thread objects for the debugger.
    self._management_threads_name_to_function = {}

  def _add_worker_process(self):
    """Signal the worker process spawner to create a new process."""
    self._signal_queue.put(_CREATE_WORKER_PROCESS)
    self._process_count += 1

  @_store_exception
  def _get_tasks_from_iterator(self):
    """Adds tasks from self._task_iterator to the executor.

    This involves adding tasks to self._task_graph, marking them as submitted,
    and adding them to self._executable_tasks.
    """
    while self._accepting_new_tasks:
      try:
        task_object = next(self._task_iterator)
      except StopIteration:
        break
      task_wrapper = self._task_graph.add(task_object)
      if task_wrapper is None:
        # self._task_graph rejected the task.
        continue
      task_wrapper.is_submitted = True
      # Tasks from task_iterator should have a lower priority than tasks that
      # are spawned by other tasks. This helps keep memory usage under control
      # when a workload's task graph has a large branching factor.
      self._executable_tasks.put(task_wrapper, prioritize=False)

  @_store_exception
  def _add_executable_tasks_to_queue(self):
    """Sends executable tasks to consumer threads in child processes."""
    task_wrapper = None
    while True:
      if task_wrapper is None:
        task_wrapper = self._executable_tasks.get()
        if task_wrapper == _SHUTDOWN:
          break

      # Only block on the put once the process limit is reached; before that,
      # a full queue is a signal to spawn another worker process instead.
      reached_process_limit = self._process_count >= self._max_process_count

      try:
        self._task_queue.put(task_wrapper, block=reached_process_limit)
        task_wrapper = None  # Task handed off; fetch a new one next loop.
      except queue.Full:
        if self._idle_thread_count.acquire(block=False):
          # Idle worker will take a task. Restore semaphore count.
          self._idle_thread_count.release()
        else:
          self._add_worker_process()

  @_store_exception
  def _handle_task_output(self):
    """Updates a dependency graph based on information from executed tasks."""
    while True:
      output = self._task_output_queue.get()
      if output == _SHUTDOWN:
        break
      executed_task_wrapper, task_output = output
      if task_output and task_output.messages:
        for message in task_output.messages:
          if message.topic in (task.Topic.CHANGE_EXIT_CODE,
                               task.Topic.FATAL_ERROR):
            self._exit_code = 1
            if message.topic == task.Topic.FATAL_ERROR:
              # Stop pulling new tasks from the iterator; in-flight tasks
              # still complete normally.
              self._accepting_new_tasks = False

      submittable_tasks = self._task_graph.update_from_executed_task(
          executed_task_wrapper, task_output)
      for task_wrapper in submittable_tasks:
        task_wrapper.is_submitted = True
        self._executable_tasks.put(task_wrapper)

  def _clean_worker_process_spawner(self, worker_process_spawner):
    """Common method which carries out the required steps to clean up worker processes.

    Args:
      worker_process_spawner (Process): The worker parent process that we need
        to clean up.
    """
    # Shutdown all the workers.
    if worker_process_spawner.is_alive():
      self._signal_queue.put(_SHUTDOWN)
      worker_process_spawner.join()

    # Restore the debug signal handler.
    self._debug_handler.terminate()

  def run(self):
    """Executes tasks from a task iterator in parallel.

    Returns:
      An integer indicating the exit code. Zero indicates no fatal errors were
        raised.
    """
    shared_process_context = SharedProcessContext()
    self._debug_handler.install()
    worker_process_spawner = multiprocessing_context.Process(
        target=_process_factory,
        args=(
            self._task_queue,
            self._task_output_queue,
            self._task_status_queue,
            self._thread_count,
            self._idle_thread_count,
            self._signal_queue,
            shared_process_context,
            self.stack_trace_file_path
        ),
    )
    worker_process_cleaned_up = False
    try:
      worker_process_spawner.start()
      # It is now safe to start the progress_manager thread, since new
      # processes are started by a child process.
      with task_status.progress_manager(
          self._task_status_queue, self._progress_manager_args
      ):
        try:
          self._add_worker_process()
          get_tasks_from_iterator_thread = threading.Thread(
              target=self._get_tasks_from_iterator
          )
          add_executable_tasks_to_queue_thread = threading.Thread(
              target=self._add_executable_tasks_to_queue
          )
          handle_task_output_thread = threading.Thread(
              target=self._handle_task_output
          )
          get_tasks_from_iterator_thread.start()
          add_executable_tasks_to_queue_thread.start()
          handle_task_output_thread.start()
          if task_graph_debugger.is_task_graph_debugging_enabled():
            self._management_threads_name_to_function[
                'get_tasks_from_iterator'
            ] = get_tasks_from_iterator_thread
            self._management_threads_name_to_function[
                'add_executable_tasks_to_queue'
            ] = add_executable_tasks_to_queue_thread
            self._management_threads_name_to_function['handle_task_output'] = (
                handle_task_output_thread
            )
            task_graph_debugger.start_thread_for_task_graph_debugging(
                self._management_threads_name_to_function,
                self.stack_trace_file_path,
                self._task_graph,
                self._executable_tasks,
            )
          get_tasks_from_iterator_thread.join()
          try:
            # Wait until every submitted task (and tasks they spawned) is done.
            self._task_graph.is_empty.wait()
          except console_io.OperationCancelledError:
            # If user hits ctrl-c, there will be no thread to pop tasks from the
            # graph. Python garbage collection will remove unstarted tasks in
            # the graph if we skip this endless wait.
            pass
          self._executable_tasks.put(_SHUTDOWN)
          self._task_output_queue.put(_SHUTDOWN)
          handle_task_output_thread.join()
          add_executable_tasks_to_queue_thread.join()
        finally:
          # By calling the clean-up in the finally block, we ensure that the
          # progress manager exit is called first. This also handles the
          # scenario where an exception is thrown by the progress manager
          # itself.
          self._clean_worker_process_spawner(worker_process_spawner)
          worker_process_cleaned_up = True
    except Exception as e:  # pylint: disable=broad-exception-caught
      # If an exception occurs while spinning up the worker process spawner or
      # while starting the progress manager context, we still need to clean
      # up, hence this method call. Note that this clean-up only occurs if an
      # exception occurs: the finally block within the progress manager
      # context executes on both error and normal completion, and if it ran
      # there is a small chance this block runs as well — the
      # worker_process_cleaned_up flag prevents a double clean-up.
      if not worker_process_cleaned_up:
        self._clean_worker_process_spawner(worker_process_spawner)
      # Raise it back as we still want main process to exit
      raise e

    # Queue close calls need to be outside the worker process spawner context
    # manager since the task queue need to be open for the shutdown logic.
    self._task_queue.close()
    self._task_output_queue.close()

    # Surface the first exception stored by a management thread, if any.
    with self.thread_exception_lock:
      if self.thread_exception:
        raise self.thread_exception  # pylint: disable=raising-bad-type

    return self._exit_code

View File

@@ -0,0 +1,417 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for monitoring and reporting task statuses."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import abc
import collections
import datetime
import enum
import threading
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import manifest_util
from googlecloudsdk.command_lib.storage import metrics_util
from googlecloudsdk.command_lib.storage import thread_messages
from googlecloudsdk.core import log
from googlecloudsdk.core.console import progress_tracker
from googlecloudsdk.core.util import scaled_integer
import six
# Recalculate throughput every time (last message time - window_start_time)
# exceeds this many seconds.
_THROUGHPUT_WINDOW_THRESHOLD_SECONDS = 3
class OperationName(enum.Enum):
  """Human-readable names of data-transfer operations for progress display."""
  DOWNLOADING = 'Downloading'
  INTRA_CLOUD_COPYING = 'Intra-Cloud Copying'
  DAISY_CHAIN_COPYING = 'Daisy Chain Copying'
  UPLOADING = 'Uploading'
class IncrementType(enum.Enum):
  """How progress is measured: plain item counts, or file and byte totals."""
  INTEGER = 'INTEGER'
  FILES_AND_BYTES = 'FILES_AND_BYTES'
# Fields:
#   increment_type (IncrementType): Determines which status tracker is used.
#   manifest_path (str|None): If set, per-file results are written to a
#     manifest file at this path.
ProgressManagerArgs = collections.namedtuple(
    'ProgressManagerArgs', ['increment_type', 'manifest_path'])
class FileProgress:
  """Holds progress information for file being copied.

  Attributes:
    component_progress (dict<int,int>): Records bytes copied per component. If
      not multi-component copy (e.g. "sliced download"), there will only be one
      component.
    start_time (datetime|None): Needed if writing file copy results to manifest.
    total_bytes_copied (int|None): Sum of bytes copied for each component.
      Needed because components are popped when completed, but we don't want to
      lose info on them if writing to the manifest.
    error_occurred (bool): Whether an error occurred during the operation.
  """

  def __init__(
      self,
      component_count,
      start_time=None,
      total_bytes_copied=None,
      error_occurred=False,
  ):
    # One zeroed bytes-copied counter per component; single-component copies
    # get exactly one entry keyed by 0.
    self.component_progress = dict.fromkeys(range(component_count), 0)
    self.start_time = start_time
    self.total_bytes_copied = total_bytes_copied
    self.error_occurred = error_occurred
def _get_formatted_throughput(bytes_processed, time_delta):
  """Returns a binary-scaled '<rate>/s' string, clamping negative rates to 0."""
  rate = max(bytes_processed / time_delta, 0)
  return scaled_integer.FormatBinaryNumber(rate, decimal_places=1) + '/s'
class _StatusTracker(six.with_metaclass(abc.ABCMeta, object)):
  """Abstract class for tracking and displaying operation progress."""

  @abc.abstractmethod
  def _get_status_string(self):
    """Generates string to illustrate progress to the user."""
    pass

  def _get_done_string(self):
    """Generates string for when StatusTracker exits."""
    return '\n'

  @abc.abstractmethod
  def add_message(self, status_message):
    """Processes task status message for printing and aggregation.

    Args:
      status_message (thread_messages.*): Message to process.
    """
    pass

  def start(self):
    """Starts a console ProgressTracker that renders this tracker's status."""
    # The tracker re-invokes _get_status_string via the callback each time it
    # refreshes the display.
    self._progress_tracker = progress_tracker.ProgressTracker(
        message=' ',
        detail_message_callback=self._get_status_string,
        done_message_callback=self._get_done_string,
        no_spacing=True)
    self._progress_tracker.__enter__()
    return self

  def stop(self, exc_type, exc_val, exc_tb):
    """Tears down the ProgressTracker, forwarding any exception info."""
    if self._progress_tracker:
      self._progress_tracker.__exit__(exc_type, exc_val, exc_tb)
class _IntegerStatusTracker(_StatusTracker):
  """See super class. Tracks a simple count of completed items."""

  def __init__(self):
    super(_IntegerStatusTracker, self).__init__()
    self._completed = 0
    self._total_estimation = 0

  def _get_status_string(self):
    """See super class."""
    if self._total_estimation:
      completed_display = '{}/{}'.format(self._completed,
                                         self._total_estimation)
    else:
      completed_display = self._completed
    return 'Completed {}\r'.format(completed_display)

  def add_message(self, status_message):
    """See super class."""
    if isinstance(status_message, thread_messages.WorkloadEstimatorMessage):
      # Estimates arrive ahead of execution and grow the denominator.
      self._total_estimation += status_message.item_count
    elif isinstance(status_message, thread_messages.IncrementProgressMessage):
      self._completed += 1
class _FilesAndBytesStatusTracker(_StatusTracker, metrics_util.MetricsReporter):
  """See super class. Tracks both file count and byte amount."""

  def __init__(self, manifest_path=None):
    super(_FilesAndBytesStatusTracker, self).__init__()
    # For displaying progress.
    self._completed_files = 0
    self._processed_bytes = 0
    self._total_files_estimation = 0
    self._total_bytes_estimation = 0

    # For calculating average throughput.
    self._first_operation_time = None
    self._last_operation_time = None
    # NOTE(review): _total_processed_bytes is never updated anywhere in this
    # class; stop() computes the average from _processed_bytes instead.
    # Presumably vestigial — confirm before relying on it.
    self._total_processed_bytes = 0

    # For calculating window throughput.
    self._window_start_time = None
    self._window_processed_bytes = 0
    # String for on-the-fly display.
    self._window_throughput = None

    # For keeping track of progress on different files.
    # Maps source URL string -> FileProgress.
    self._tracked_file_progress = {}

    if manifest_path:
      self._manifest_manager = manifest_util.ManifestManager(manifest_path)
    else:
      self._manifest_manager = None

  def _get_status_string(self):
    """See super class."""
    scaled_processed_bytes = scaled_integer.FormatBinaryNumber(
        self._processed_bytes, decimal_places=1)
    if self._total_files_estimation:
      file_progress_string = '{}/{}'.format(self._completed_files,
                                            self._total_files_estimation)
    else:
      file_progress_string = self._completed_files
    if self._total_bytes_estimation:
      scaled_total_bytes_estimation = scaled_integer.FormatBinaryNumber(
          self._total_bytes_estimation, decimal_places=1)
      bytes_progress_string = '{}/{}'.format(scaled_processed_bytes,
                                             scaled_total_bytes_estimation)
    else:
      bytes_progress_string = scaled_processed_bytes
    if self._window_throughput:
      throughput_addendum_string = ' | ' + self._window_throughput
    else:
      throughput_addendum_string = ''
    return 'Completed files {} | {}{}\r'.format(file_progress_string,
                                                bytes_progress_string,
                                                throughput_addendum_string)

  def _update_throughput(self, status_message, processed_bytes):
    """Updates stats and recalculates throughput if past threshold."""
    if self._first_operation_time is None:
      # First message: open the throughput window.
      self._first_operation_time = status_message.time
      self._window_start_time = status_message.time
    else:
      self._last_operation_time = status_message.time

    self._window_processed_bytes += processed_bytes
    time_delta = status_message.time - self._window_start_time
    if time_delta > _THROUGHPUT_WINDOW_THRESHOLD_SECONDS:
      # Window elapsed: recompute the displayed rate and reset the window.
      self._window_throughput = _get_formatted_throughput(
          self._window_processed_bytes, time_delta)
      self._window_start_time = status_message.time
      self._window_processed_bytes = 0

  def _add_to_workload_estimation(self, status_message):
    """Adds WorkloadEstimatorMessage info to total workload estimation."""
    self._total_files_estimation += status_message.item_count
    self._total_bytes_estimation += status_message.size

  def _add_progress(self, status_message):
    """Track progress of a multipart file operation."""
    file_url_string = status_message.source_url.url_string

    if file_url_string not in self._tracked_file_progress:
      # First message for this file: set up per-component tracking.
      if status_message.total_components:
        self._tracked_file_progress[file_url_string] = FileProgress(
            component_count=status_message.total_components)
      else:
        self._tracked_file_progress[file_url_string] = FileProgress(
            component_count=1)
      if self._manifest_manager:
        self._tracked_file_progress[file_url_string].start_time = (
            datetime.datetime.fromtimestamp(status_message.time,
                                            datetime.timezone.utc))
        self._tracked_file_progress[file_url_string].total_bytes_copied = 0

    component_tracker = self._tracked_file_progress[
        file_url_string].component_progress
    if status_message.component_number:
      component_number = status_message.component_number
    else:
      component_number = 0

    processed_component_bytes = (
        status_message.current_byte - status_message.offset)
    # status_message.current_byte includes bytes from past messages.
    newly_processed_bytes = (
        processed_component_bytes - component_tracker.get(component_number, 0))
    self._processed_bytes += newly_processed_bytes
    self._update_throughput(status_message, newly_processed_bytes)

    if self._manifest_manager:
      # Keep track of total bytes per file for writing to manifest.
      self._tracked_file_progress[
          file_url_string].total_bytes_copied += newly_processed_bytes

    if status_message.error_occurred:
      # If an error occurred, mark the file as failed.
      self._tracked_file_progress[file_url_string].error_occurred = True

    if processed_component_bytes == status_message.length:
      # Operation complete.
      component_tracker.pop(component_number, None)
      if not component_tracker:
        # All components finished for this file.
        if not self._tracked_file_progress[file_url_string].error_occurred:
          # Count as completed, if no error occurred.
          self._completed_files += 1
        if not self._manifest_manager:
          # If managing manifest, _add_to_manifest clears items from tracking.
          del self._tracked_file_progress[file_url_string]
    else:
      component_tracker[component_number] = processed_component_bytes

  def _add_to_manifest(self, status_message):
    """Updates manifest file and pops file from tracking if needed."""
    if not self._manifest_manager:
      raise errors.Error(
          'Received ManifestMessage but StatusTracker was not initialized with'
          ' manifest path.'
      )
    file_progress = self._tracked_file_progress.pop(
        status_message.source_url.url_string, None)
    self._manifest_manager.write_row(status_message, file_progress)

  def add_message(self, status_message):
    """See super class."""
    if isinstance(status_message, thread_messages.WorkloadEstimatorMessage):
      self._add_to_workload_estimation(status_message)
    elif isinstance(status_message, thread_messages.DetailedProgressMessage):
      self._set_source_and_destination_schemes(status_message)
      # If files start getting counted twice, see b/225182075.
      self._add_progress(status_message)
    elif isinstance(status_message, thread_messages.IncrementProgressMessage):
      self._completed_files += 1
    elif isinstance(status_message, thread_messages.ManifestMessage):
      self._add_to_manifest(status_message)

  def stop(self, exc_type, exc_val, exc_tb):
    """See super class. Also prints average throughput and reports metrics."""
    super(_FilesAndBytesStatusTracker, self).stop(exc_type, exc_val, exc_tb)
    if (self._first_operation_time is not None and
        self._last_operation_time is not None and
        self._first_operation_time != self._last_operation_time):
      time_delta = self._last_operation_time - self._first_operation_time
      # Don't use get_done_string because it may cause line wrapping.
      log.status.Print('\nAverage throughput: {}'.format(
          _get_formatted_throughput(self._processed_bytes, time_delta)))
      # Report event for analytics tracking, if enabled.
      self._report_metrics(self._processed_bytes, time_delta,
                           self._completed_files)
def status_message_handler(task_status_queue, status_tracker):
  """Thread method for submitting items from queue to tracker for processing."""
  saw_unhandled_message = False

  while True:
    message = task_status_queue.get()
    if message == '_SHUTDOWN':
      break
    if status_tracker:
      status_tracker.add_message(message)
    else:
      # No tracker attached; remember so we can warn once after draining.
      saw_unhandled_message = True

  if saw_unhandled_message:
    log.warning('Status message submitted to task_status_queue without a'
                ' manager to print it.')
def progress_manager(task_status_queue=None, progress_manager_args=None):
  """Factory function that returns a ProgressManager instance.

  Args:
    task_status_queue (multiprocessing.Queue|None): Tasks can submit their
      progress messages here.
    progress_manager_args (ProgressManagerArgs|None): Determines what type of
      progress indicator to display.

  Returns:
    An instance of _ProgressManager or _NoOpProgressManager.
  """
  if task_status_queue is None:
    # Nothing to listen on, so progress display is a no-op.
    return _NoOpProgressManager()
  return _ProgressManager(task_status_queue, progress_manager_args)
class _ProgressManager:
  """Context manager for processing and displaying progress completing command.

  Ensure that this class is instantiated after all the child
  processes (if any) are started to prevent deadlock.
  """

  def __init__(self, task_status_queue, progress_manager_args=None):
    """Initializes context manager.

    Args:
      task_status_queue (multiprocessing.Queue): Tasks can submit their progress
        messages here.
      progress_manager_args (ProgressManagerArgs|None): Determines what type of
        progress indicator to display.
    """
    self._task_status_queue = task_status_queue
    self._progress_manager_args = progress_manager_args
    # Populated lazily in __enter__.
    self._status_tracker = None
    self._status_message_handler_thread = None

  def _make_tracker(self):
    """Returns a status tracker matching the configured increment type, or None."""
    if not self._progress_manager_args:
      return None
    increment_type = self._progress_manager_args.increment_type
    if increment_type is IncrementType.INTEGER:
      return _IntegerStatusTracker()
    if increment_type is IncrementType.FILES_AND_BYTES:
      return _FilesAndBytesStatusTracker(
          self._progress_manager_args.manifest_path)
    return None

  def __enter__(self):
    self._status_tracker = self._make_tracker()
    # Start draining the queue before starting the tracker display so no
    # messages pile up unprocessed.
    self._status_message_handler_thread = threading.Thread(
        target=status_message_handler,
        args=(self._task_status_queue, self._status_tracker))
    self._status_message_handler_thread.start()
    if self._status_tracker:
      self._status_tracker.start()
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    # Sentinel tells the handler thread to exit its loop; join before
    # stopping the tracker so every queued message is displayed.
    self._task_status_queue.put('_SHUTDOWN')
    self._status_message_handler_thread.join()
    if self._status_tracker:
      self._status_tracker.stop(exc_type, exc_val, exc_tb)
class _NoOpProgressManager:
"""Progress Manager that does not do anything.
Similar to contextlib.nullcontext, but it is available only for Python3.7+.
"""
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
del exc_type, exc_val, exc_tb # Unused.
pass

View File

@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for task execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import sys
from googlecloudsdk.command_lib.storage import errors
from googlecloudsdk.command_lib.storage import optimize_parameters_util
from googlecloudsdk.core import properties
def get_first_matching_message_payload(messages, topic):
  """Gets first item with matching topic from list of task output messages."""
  # Topic comparison is by identity, matching how topics are shared as
  # module-level singletons.
  matches = (message.payload for message in messages if message.topic is topic)
  return next(matches, None)
def should_use_parallelism():
  """Checks execution settings to determine if parallelism should be used.

  This function is called in some tasks to determine how they are being
  executed, and should include as many of the relevant conditions as possible.

  Returns:
    True if parallel execution should be used, False otherwise.
  """
  configured_counts = (
      properties.VALUES.storage.process_count.GetInt(),
      properties.VALUES.storage.thread_count.GetInt(),
  )
  if None in configured_counts:
    # This can arise if optimize_parameters_util.detect_and_set_best_config has
    # not been called before this method is called. This indicates that the user
    # has not opted out of parallelism.
    return optimize_parameters_util.DEFAULT_TO_PARALLELISM
  return any(count > 1 for count in configured_counts)
def require_python_3_5():
  """Task execution assumes Python versions >=3.5.

  Raises:
    InvalidPythonVersionError: if the Python version is not 3.5+.
  """
  # Tuple comparison handles the major/minor logic in one step, avoiding the
  # error-prone hand-rolled "major < 3 or (major == 3 and minor < 5)" check.
  if sys.version_info[:2] < (3, 5):
    raise errors.InvalidPythonVersionError(
        'This functionality does not support Python {}.{}.{}. Please upgrade '
        'to Python 3.5 or greater.'.format(
            sys.version_info.major,
            sys.version_info.minor,
            sys.version_info.micro,
        ))