# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility for making API calls."""

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

from apitools.base.py import exceptions as apitools_exceptions
from apitools.base.py import http_wrapper
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.artifacts import exceptions as ar_exceptions
from googlecloudsdk.api_lib.cloudkms import iam as kms_iam
from googlecloudsdk.api_lib.iam import util as iam_api
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.iam import iam_util
from googlecloudsdk.core import resources

ARTIFACTREGISTRY_API_NAME = "artifactregistry"
ARTIFACTREGISTRY_API_VERSION = "v1"

STORAGE_API_NAME = "storage"
STORAGE_API_VERSION = "v1"

_GCR_PERMISSION = "storage.objects.list"

CRYPTO_KEY_COLLECTION = "cloudkms.projects.locations.keyRings.cryptoKeys"


def GetStorageClient():
  """Returns a client for the Cloud Storage v1 API."""
  return apis.GetClientInstance(STORAGE_API_NAME, STORAGE_API_VERSION)


def GetStorageMessages():
  """Returns the messages module for the Cloud Storage v1 API."""
  return apis.GetMessagesModule(STORAGE_API_NAME, STORAGE_API_VERSION)


def SkipRetryOn500Errors(response):
  """Wraps http_wrapper.CheckResponse to skip retries on 5xx errors.

  Any response whose status code is >= 500 raises an apitools HttpError
  immediately instead of being handed to http_wrapper.CheckResponse (which is
  what would otherwise trigger a retry). All other responses are delegated to
  http_wrapper.CheckResponse unchanged.

  Args:
    response: The apitools HTTP response to check.

  Returns:
    The result of http_wrapper.CheckResponse for non-5xx responses.

  Raises:
    apitools_exceptions.HttpError: If the response status code is >= 500.
  """
  if response.status_code >= 500:
    raise apitools_exceptions.HttpError.FromResponse(response)
  return http_wrapper.CheckResponse(response)


def GetClient(skip_activation_prompt=False):
  """Returns an Artifact Registry v1 client that does not retry 5xx errors.

  Args:
    skip_activation_prompt: bool, forwarded to apis.GetClientInstance.

  Returns:
    The Artifact Registry API client, with SkipRetryOn500Errors installed as
    its response checker.
  """
  client = apis.GetClientInstance(
      ARTIFACTREGISTRY_API_NAME,
      ARTIFACTREGISTRY_API_VERSION,
      skip_activation_prompt=skip_activation_prompt,
  )
  client.check_response_func = SkipRetryOn500Errors
  return client


def GetMessages():
  """Returns the Artifact Registry v1 messages module."""
  return apis.GetMessagesModule(ARTIFACTREGISTRY_API_NAME,
                                ARTIFACTREGISTRY_API_VERSION)


def GetClientV1beta2():
  """Returns an Artifact Registry v1beta2 client."""
  return apis.GetClientInstance(ARTIFACTREGISTRY_API_NAME, "v1beta2")


def GetMessagesV1beta2():
  """Returns the Artifact Registry v1beta2 messages module."""
  return apis.GetMessagesModule(ARTIFACTREGISTRY_API_NAME, "v1beta2")


def DeleteTag(client, messages, tag):
  """Deletes a tag by its name.

  Args:
    client: Artifact Registry API client.
    messages: Artifact Registry messages module.
    tag: str, the tag resource name.

  Raises:
    ar_exceptions.ArtifactRegistryError: If the Delete call does not return an
      Empty message.
  """
  # pylint: disable=line-too-long
  delete_tag_req = messages.ArtifactregistryProjectsLocationsRepositoriesPackagesTagsDeleteRequest(
      name=tag)
  err = client.projects_locations_repositories_packages_tags.Delete(
      delete_tag_req)
  if not isinstance(err, messages.Empty):
    raise ar_exceptions.ArtifactRegistryError(
        "Failed to delete tag {}: {}".format(tag, err))


def CreateDockerTag(client, messages, docker_tag, docker_version):
  """Creates a tag associated with the given docker version.

  Args:
    client: Artifact Registry API client.
    messages: Artifact Registry messages module.
    docker_tag: Object providing GetTagName(), GetPackageName() and a `tag`
      attribute used as the tag ID.
    docker_version: Object providing GetVersionName().

  Returns:
    The result of the tags Create call.
  """
  tag = messages.Tag(
      name=docker_tag.GetTagName(), version=docker_version.GetVersionName())
  # pylint: disable=line-too-long
  create_tag_req = messages.ArtifactregistryProjectsLocationsRepositoriesPackagesTagsCreateRequest(
      parent=docker_tag.GetPackageName(), tag=tag, tagId=docker_tag.tag)
  return client.projects_locations_repositories_packages_tags.Create(
      create_tag_req)


def GetTag(client, messages, tag):
  """Gets a tag by its name."""
  # pylint: disable=line-too-long
  get_tag_req = messages.ArtifactregistryProjectsLocationsRepositoriesPackagesTagsGetRequest(
      name=tag)
  return client.projects_locations_repositories_packages_tags.Get(get_tag_req)


def DeleteVersion(client, messages, version):
  """Deletes a version by its name and returns the Delete call's result."""
  # pylint: disable=line-too-long
  delete_ver_req = messages.ArtifactregistryProjectsLocationsRepositoriesPackagesVersionsDeleteRequest(
      name=version)
  return client.projects_locations_repositories_packages_versions.Delete(
      delete_ver_req)


def DeletePackage(client, messages, package):
  """Deletes a package by its name and returns the Delete call's result."""
  # pylint: disable=line-too-long
  delete_pkg_req = messages.ArtifactregistryProjectsLocationsRepositoriesPackagesDeleteRequest(
      name=package)
  return client.projects_locations_repositories_packages.Delete(delete_pkg_req)


def GetVersion(client, messages, version):
  """Gets a version by its name.

  Uses the caller-supplied client and messages module, like every other
  helper taking (client, messages, ...), and queries the *versions*
  collection, matching DeleteVersion and ListVersions.

  Args:
    client: Artifact Registry API client.
    messages: Artifact Registry messages module.
    version: str, the version resource name.

  Returns:
    The result of the versions Get call.
  """
  # pylint: disable=line-too-long
  get_ver_req = messages.ArtifactregistryProjectsLocationsRepositoriesPackagesVersionsGetRequest(
      name=version)
  return client.projects_locations_repositories_packages_versions.Get(
      get_ver_req)


def GetVersionFromTag(client, messages, tag):
  """Gets a version name by a tag name.

  Args:
    client: Artifact Registry API client.
    messages: Artifact Registry messages module.
    tag: str, the tag resource name.

  Returns:
    str, the last path segment of the version the tag points to.

  Raises:
    ar_exceptions.ArtifactRegistryError: If the tag has no version, or the
      version name does not have exactly 10 '/'-separated segments.
  """
  # pylint: disable=line-too-long
  get_tag_req = messages.ArtifactregistryProjectsLocationsRepositoriesPackagesTagsGetRequest(
      name=tag)
  get_tag_res = client.projects_locations_repositories_packages_tags.Get(
      get_tag_req)
  version_name = get_tag_res.version
  # A full version name is expected to have 10 segments (presumably
  # projects/*/locations/*/repositories/*/packages/*/versions/*).
  if not version_name or len(version_name.split("/")) != 10:
    raise ar_exceptions.ArtifactRegistryError(
        "Internal error. Corrupted tag: {}".format(tag))
  return version_name.split("/")[-1]


def ListTags(client, messages, package, page_size=None, server_filter=None):
  """Lists all tags under a package with the given package name.

  Args:
    client: Artifact Registry API client.
    messages: Artifact Registry messages module.
    package: str, the parent package resource name.
    page_size: int or None, sent as pageSize on each request.
    server_filter: str or None, sent as the request filter.

  Returns:
    list of tags, with all pages fetched.
  """
  # pylint: disable=line-too-long
  list_tags_req = messages.ArtifactregistryProjectsLocationsRepositoriesPackagesTagsListRequest(
      parent=package, filter=server_filter)
  return list(
      list_pager.YieldFromList(
          client.projects_locations_repositories_packages_tags,
          list_tags_req,
          batch_size=page_size,
          batch_size_attribute="pageSize",
          field="tags"))


def ListVersionTags(client, messages, package, version, page_size=None):
  """Lists tags associated with the given version.

  This is ListTags with a server-side filter on the tag's version.
  """
  return ListTags(
      client,
      messages,
      package,
      page_size=page_size,
      server_filter="version=\"{}\"".format(version))


def ListPackages(client,
                 messages,
                 repo,
                 page_size=None,
                 order_by=None,
                 limit=None,
                 server_filter=None):
  """Lists all packages under a repository.

  Args:
    client: Artifact Registry API client.
    messages: Artifact Registry messages module.
    repo: str, the parent repository resource name.
    page_size: int or None, sent as pageSize on each request.
    order_by: str or None, sent as orderBy.
    limit: int or None, maximum total number of packages to return.
    server_filter: str or None, sent as the request filter.

  Returns:
    list of packages.
  """
  list_pkgs_req = (
      messages.ArtifactregistryProjectsLocationsRepositoriesPackagesListRequest(
          parent=repo, orderBy=order_by, filter=server_filter))
  return list(
      list_pager.YieldFromList(
          client.projects_locations_repositories_packages,
          list_pkgs_req,
          limit=limit,
          batch_size=page_size,
          batch_size_attribute="pageSize",
          field="packages"))


def ListVersions(client,
                 messages,
                 pkg,
                 version_view=None,
                 page_size=None,
                 order_by=None,
                 limit=None,
                 server_filter=None):
  """Lists all versions under a package.

  Args:
    client: Artifact Registry API client.
    messages: Artifact Registry messages module.
    pkg: str, the parent package resource name.
    version_view: Sent as the request view.
    page_size: int or None, requested page size.
    order_by: str or None, sent as orderBy.
    limit: int or None, maximum total number of versions to return.
    server_filter: str or None, sent as the request filter.

  Returns:
    list of versions.
  """
  # Batch size is page_size when there is no limit, or when page_size is set
  # and smaller than the limit; otherwise it is the limit itself. This avoids
  # fetching pages larger than the number of results wanted.
  page_limit = limit
  if limit is None or (page_size is not None and page_size < limit):
    page_limit = page_size
  # pylint: disable=line-too-long
  list_vers_req = messages.ArtifactregistryProjectsLocationsRepositoriesPackagesVersionsListRequest(
      parent=pkg, view=version_view, orderBy=order_by, filter=server_filter)
  return list(
      list_pager.YieldFromList(
          client.projects_locations_repositories_packages_versions,
          list_vers_req,
          limit=limit,
          batch_size=page_limit,
          batch_size_attribute="pageSize",
          field="versions"))


def ListRepositories(project, page_size=None, order_by=None,
                     server_filter=None):
  """Lists all repositories under a project.

  Args:
    project: str, the parent resource name sent as `parent`.
    page_size: int or None, sent as pageSize on each request.
    order_by: str or None, sent as orderBy.
    server_filter: str or None, sent as the request filter.

  Returns:
    list of repositories.
  """
  client = GetClient()
  messages = GetMessages()
  list_repos_req = (
      messages.ArtifactregistryProjectsLocationsRepositoriesListRequest(
          parent=project, orderBy=order_by, filter=server_filter))
  return list(
      list_pager.YieldFromList(
          client.projects_locations_repositories,
          list_repos_req,
          batch_size=page_size,
          batch_size_attribute="pageSize",
          field="repositories"))


def ListFiles(client,
              messages,
              repo,
              server_filter=None,
              page_size=None,
              order_by=None):
  """Lists all files under a repository.

  Args:
    client: Artifact Registry API client.
    messages: Artifact Registry messages module.
    repo: str, the parent repository resource name.
    server_filter: str or None, sent as the request filter.
    page_size: int or None, sent as pageSize on each request.
    order_by: str or None, sent as orderBy.

  Returns:
    list of files.
  """
  list_files_req = (
      messages.ArtifactregistryProjectsLocationsRepositoriesFilesListRequest(
          parent=repo, filter=server_filter, orderBy=order_by))
  return list(
      list_pager.YieldFromList(
          client.projects_locations_repositories_files,
          list_files_req,
          batch_size=page_size,
          batch_size_attribute="pageSize",
          field="files"))


def GetRepository(repo, skip_activation_prompt=False):
  """Gets the repository given its name.

  Args:
    repo: str, the repository resource name.
    skip_activation_prompt: bool, forwarded to GetClient.

  Returns:
    The result of the repositories Get call.
  """
  client = GetClient(skip_activation_prompt)
  messages = GetMessages()
  get_repo_req = messages.ArtifactregistryProjectsLocationsRepositoriesGetRequest(
      name=repo)
  get_repo_res = client.projects_locations_repositories.Get(get_repo_req)
  return get_repo_res


def GetIamPolicy(repo_res):
  """Gets the IAM policy for the specified repository."""
  client = GetClient()
  messages = GetMessages()
  # pylint: disable=line-too-long
  get_iam_policy_req = messages.ArtifactregistryProjectsLocationsRepositoriesGetIamPolicyRequest(
      resource=repo_res)
  get_iam_policy_res = client.projects_locations_repositories.GetIamPolicy(
      get_iam_policy_req)
  return get_iam_policy_res


def SetIamPolicy(repo_res, policy):
  """Sets the IAM policy for the specified repository."""
  client = GetClient()
  ar_messages = GetMessages()
  # pylint: disable=line-too-long
  set_iam_policy_req = ar_messages.ArtifactregistryProjectsLocationsRepositoriesSetIamPolicyRequest(
      resource=repo_res,
      setIamPolicyRequest=ar_messages.SetIamPolicyRequest(policy=policy),
  )
  return client.projects_locations_repositories.SetIamPolicy(set_iam_policy_req)


def CreateRepository(
    project, location, repository, skip_activation_prompt=False
):
  """Creates the repository given its parent.

  Args:
    project: str: The project to create the repository in.
    location: str: The region to create the repository in.
    repository: messages.Repository to create.
    skip_activation_prompt: bool: If true, do not prompt for service activation

  Returns:
    The resulting operation from the create request.
  """
  client = GetClient(skip_activation_prompt)
  messages = GetMessages()
  request = messages.ArtifactregistryProjectsLocationsRepositoriesCreateRequest(
      parent="projects/{}/locations/{}".format(project, location),
      # The repository ID is the last segment of the repository's full name.
      repositoryId=repository.name.split("/")[-1],
      repository=repository)
  return client.projects_locations_repositories.Create(request)


def GetPackage(package):
  """Gets the package given its name."""
  client = GetClient()
  messages = GetMessages()
  # pylint: disable=line-too-long
  get_package_req = messages.ArtifactregistryProjectsLocationsRepositoriesPackagesGetRequest(
      name=package)
  get_package_res = client.projects_locations_repositories_packages.Get(
      get_package_req)
  return get_package_res


def ListLocations(project_id, page_size=None):
  """Lists all locations for a given project.

  Uses the v1beta2 API surface.

  Args:
    project_id: str, the project ID.
    page_size: int or None, sent as pageSize on each request.

  Returns:
    Sorted list of location IDs.
  """
  client = GetClientV1beta2()
  messages = GetMessagesV1beta2()
  list_locs_req = messages.ArtifactregistryProjectsLocationsListRequest(
      name="projects/" + project_id)
  locations = list_pager.YieldFromList(
      client.projects_locations,
      list_locs_req,
      batch_size=page_size,
      batch_size_attribute="pageSize",
      field="locations")
  return sorted([loc.locationId for loc in locations])


def TestStorageIAMPermission(bucket, project):
  """Tests storage IAM permission for a given bucket for the user project."""
  client = GetStorageClient()
  messages = GetStorageMessages()
  test_req = messages.StorageBucketsTestIamPermissionsRequest(
      bucket=bucket, permissions=_GCR_PERMISSION, userProject=project)
  return client.buckets.TestIamPermissions(test_req)


def GetCryptoKeyPolicy(kms_key):
  """Gets the IAM policy for a given crypto key (a relative resource name)."""
  crypto_key_ref = resources.REGISTRY.ParseRelativeName(
      relative_name=kms_key, collection=CRYPTO_KEY_COLLECTION)
  return kms_iam.GetCryptoKeyIamPolicy(crypto_key_ref)


def AddCryptoKeyPermission(kms_key, service_account):
  """Adds Encrypter/Decrypter role to the given service account."""
  crypto_key_ref = resources.REGISTRY.ParseRelativeName(
      relative_name=kms_key, collection=CRYPTO_KEY_COLLECTION)
  return kms_iam.AddPolicyBindingToCryptoKey(
      crypto_key_ref, service_account,
      "roles/cloudkms.cryptoKeyEncrypterDecrypter")


def GetServiceAccount(service_account):
  """Gets the service account given its email."""
  client, messages = iam_api.GetClientAndMessages()
  return client.projects_serviceAccounts.Get(
      messages.IamProjectsServiceAccountsGetRequest(
          name=iam_util.EmailToAccountResourceName(service_account)))


def GetProjectSettings(project_id):
  """Gets the Artifact Registry project settings for the given project ID."""
  client = GetClient()
  messages = GetMessages()
  get_settings_req = messages.ArtifactregistryProjectsGetProjectSettingsRequest(
      name="projects/" + project_id + "/projectSettings")
  return client.projects.GetProjectSettings(get_settings_req)


def _VpcscConfigName(project_id, location_id):
  """Returns the VPC SC config resource name for a project and location."""
  return ("projects/" + project_id + "/locations/" + location_id +
          "/vpcscConfig")


def _UpdateVpcscPolicy(project_id, location_id, allow):
  """Sets the VPC SC policy of a project/location to ALLOW or DENY.

  Args:
    project_id: str, the project ID.
    location_id: str, the location ID.
    allow: bool, True for ALLOW, False for DENY.

  Returns:
    The result of the UpdateVpcscConfig call.
  """
  client = GetClient()
  messages = GetMessages()
  policy_enum = messages.VPCSCConfig.VpcscPolicyValueValuesEnum
  name = _VpcscConfigName(project_id, location_id)
  vpcsc_config = messages.VPCSCConfig(
      name=name,
      vpcscPolicy=policy_enum.ALLOW if allow else policy_enum.DENY)
  update_vpcsc_req = (
      messages.ArtifactregistryProjectsLocationsUpdateVpcscConfigRequest(
          name=name, vPCSCConfig=vpcsc_config))
  return client.projects_locations.UpdateVpcscConfig(update_vpcsc_req)


def GetVPCSCConfig(project_id, location_id):
  """Gets VPC SC Config on the project and location."""
  client = GetClient()
  messages = GetMessages()
  get_vpcsc_req = messages.ArtifactregistryProjectsLocationsGetVpcscConfigRequest(
      name=_VpcscConfigName(project_id, location_id))
  return client.projects_locations.GetVpcscConfig(get_vpcsc_req)


def AllowVPCSCConfig(project_id, location_id):
  """Allows requests in Remote Repository inside VPC SC perimeter."""
  return _UpdateVpcscPolicy(project_id, location_id, allow=True)


def DenyVPCSCConfig(project_id, location_id):
  """Denies requests in Remote Repository inside VPC SC perimeter."""
  return _UpdateVpcscPolicy(project_id, location_id, allow=False)


def EnableUpgradeRedirection(project_id):
  """Sets the project's legacy redirection state to REDIRECTION_FROM_GCR_IO_ENABLED."""
  messages = GetMessages()
  return SetUpgradeRedirectionState(
      project_id, messages.ProjectSettings.LegacyRedirectionStateValueValuesEnum
      .REDIRECTION_FROM_GCR_IO_ENABLED)


def DisableUpgradeRedirection(project_id):
  """Sets the project's legacy redirection state to REDIRECTION_FROM_GCR_IO_DISABLED."""
  messages = GetMessages()
  return SetUpgradeRedirectionState(
      project_id, messages.ProjectSettings.LegacyRedirectionStateValueValuesEnum
      .REDIRECTION_FROM_GCR_IO_DISABLED)


def FinalizeUpgradeRedirection(project_id):
  """Sets the project's legacy redirection state to REDIRECTION_FROM_GCR_IO_FINALIZED."""
  messages = GetMessages()
  return SetUpgradeRedirectionState(
      project_id, messages.ProjectSettings.LegacyRedirectionStateValueValuesEnum
      .REDIRECTION_FROM_GCR_IO_FINALIZED)


def SetUpgradeRedirectionState(
    project_id, redirection_state, pull_percent=None
):
  """Sets the upgrade redirection state for the supplied project.

  Args:
    project_id: str, the project ID.
    redirection_state: A ProjectSettings.LegacyRedirectionStateValueValuesEnum
      value.
    pull_percent: int or None. When truthy, it is also written to the project
      settings and included in the update mask.

  Returns:
    The result of the UpdateProjectSettings call.
  """
  client = GetClient()
  messages = GetMessages()
  project_settings = messages.ProjectSettings(
      legacyRedirectionState=redirection_state)
  update_mask = "legacy_redirection_state"
  # NOTE(review): 0 is treated as "not supplied", matching the previous
  # truthiness check.
  if pull_percent:
    project_settings.pullPercent = pull_percent
    # Fields not named in updateMask are ignored by the server, so the
    # percentage must be listed for it to take effect.
    update_mask += ",pull_percent"
  update_settings_req = (
      messages.ArtifactregistryProjectsUpdateProjectSettingsRequest(
          name="projects/" + project_id + "/projectSettings",
          projectSettings=project_settings,
          updateMask=update_mask,
      )
  )
  return client.projects.UpdateProjectSettings(update_settings_req)


# TODO(b/339473586): If possible annotate list DockerImage output.
def ListDockerImages(parent: str, page_size: int, limit: int):
  """Returns the Docker images of repository `parent` as a list.

  Args:
    parent: Resource name of the repository whose images are listed.
    page_size: Requested page size, sent as pageSize on each request.
    limit: Maximum total number of images to return.

  Returns:
    list of Docker image messages.
  """
  ar_client = GetClient()
  ar_messages = GetMessages()
  # pylint: disable=line-too-long
  request = ar_messages.ArtifactregistryProjectsLocationsRepositoriesDockerImagesListRequest(
      parent=parent)
  pager = list_pager.YieldFromList(
      ar_client.projects_locations_repositories_dockerImages,
      request,
      field="dockerImages",
      batch_size_attribute="pageSize",
      batch_size=page_size,
      limit=limit,
  )
  return list(pager)


def CopyRepository(source_repo, dest_repo_name):
  """Copies repository `source_repo` into `dest_repo_name`.

  Returns:
    The result of the CopyRepository call.
  """
  ar_client = GetClient()
  ar_messages = GetMessages()
  copy_body = ar_messages.CopyRepositoryRequest(sourceRepository=source_repo)
  # pylint: disable=line-too-long
  request = ar_messages.ArtifactregistryProjectsLocationsRepositoriesCopyRepositoryRequest(
      destinationRepository=dest_repo_name,
      copyRepositoryRequest=copy_body,
  )
  return ar_client.projects_locations_repositories.CopyRepository(request)


def ExportArtifact(version, tag, gcs_destination):
  """Exports an artifact, identified by a version or a tag, to Cloud Storage.

  Args:
    version: Version resource reference, or a falsy value. Takes precedence
      over `tag` when truthy.
    tag: Tag resource reference, or a falsy value.
    gcs_destination: str, sent as gcsPath.

  Returns:
    The result of the ExportArtifact call.

  Raises:
    ValueError: If neither version nor tag is given.
  """
  ar_client = GetClient()
  ar_messages = GetMessages()
  if version:
    source_ref, source_field = version, "sourceVersion"
  elif tag:
    source_ref, source_field = tag, "sourceTag"
  else:
    raise ValueError("Either version or tag must be specified.")
  # Versions and tags both sit under a package, so the owning repository is
  # two Parent() steps up from either reference.
  repository_name = source_ref.Parent().Parent().RelativeName()
  export_body = ar_messages.ExportArtifactRequest(
      gcsPath=gcs_destination, **{source_field: source_ref.RelativeName()})
  # pylint: disable=line-too-long
  request = ar_messages.ArtifactregistryProjectsLocationsRepositoriesExportArtifactRequest(
      repository=repository_name,
      exportArtifactRequest=export_body,
  )
  return ar_client.projects_locations_repositories.ExportArtifact(request)