feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Access approval requests API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
def Approve(name):
  """Approves the approval request with the given resource name.

  Args:
    name: the resource name of the approval request to approve. The parent
      collection (organization, folder, or project) is inferred from the name.

  Returns:
    The approved ApprovalRequest message.
  """
  api_client = apis.GetClientInstance('accessapproval', 'v1')
  messages = apis.GetMessagesModule('accessapproval', 'v1')
  if 'organizations/' in name:
    return api_client.organizations_approvalRequests.Approve(
        messages.AccessapprovalOrganizationsApprovalRequestsApproveRequest(
            name=name))
  if 'folders/' in name:
    return api_client.folders_approvalRequests.Approve(
        messages.AccessapprovalFoldersApprovalRequestsApproveRequest(name=name))
  return api_client.projects_approvalRequests.Approve(
      messages.AccessapprovalProjectsApprovalRequestsApproveRequest(name=name))
def Dismiss(name):
  """Dismisses the approval request with the given resource name.

  Args:
    name: the resource name of the approval request to dismiss. The parent
      collection (organization, folder, or project) is inferred from the name.

  Returns:
    The dismissed ApprovalRequest message.
  """
  api_client = apis.GetClientInstance('accessapproval', 'v1')
  messages = apis.GetMessagesModule('accessapproval', 'v1')
  if 'organizations/' in name:
    return api_client.organizations_approvalRequests.Dismiss(
        messages.AccessapprovalOrganizationsApprovalRequestsDismissRequest(
            name=name))
  if 'folders/' in name:
    return api_client.folders_approvalRequests.Dismiss(
        messages.AccessapprovalFoldersApprovalRequestsDismissRequest(name=name))
  return api_client.projects_approvalRequests.Dismiss(
      messages.AccessapprovalProjectsApprovalRequestsDismissRequest(name=name))
def Invalidate(name):
  """Invalidates the approval request with the given resource name.

  Args:
    name: the resource name of the approval request to invalidate. The parent
      collection (organization, folder, or project) is inferred from the name.

  Returns:
    The invalidated ApprovalRequest message.
  """
  api_client = apis.GetClientInstance('accessapproval', 'v1')
  messages = apis.GetMessagesModule('accessapproval', 'v1')
  if 'organizations/' in name:
    return api_client.organizations_approvalRequests.Invalidate(
        messages.AccessapprovalOrganizationsApprovalRequestsInvalidateRequest(
            name=name))
  if 'folders/' in name:
    return api_client.folders_approvalRequests.Invalidate(
        messages.AccessapprovalFoldersApprovalRequestsInvalidateRequest(
            name=name))
  return api_client.projects_approvalRequests.Invalidate(
      messages.AccessapprovalProjectsApprovalRequestsInvalidateRequest(
          name=name))
def Get(name):
  """Fetches a single approval request by resource name.

  Args:
    name: the resource name of the approval request. The parent collection
      (organization, folder, or project) is inferred from the name.

  Returns:
    The requested ApprovalRequest message.
  """
  api_client = apis.GetClientInstance('accessapproval', 'v1')
  messages = apis.GetMessagesModule('accessapproval', 'v1')
  if 'organizations/' in name:
    return api_client.organizations_approvalRequests.Get(
        messages.AccessapprovalOrganizationsApprovalRequestsGetRequest(
            name=name))
  if 'folders/' in name:
    return api_client.folders_approvalRequests.Get(
        messages.AccessapprovalFoldersApprovalRequestsGetRequest(name=name))
  return api_client.projects_approvalRequests.Get(
      messages.AccessapprovalProjectsApprovalRequestsGetRequest(name=name))
def List(parent, filter=None):  # pylint: disable=redefined-builtin
  """Lists approval requests under the given parent resource.

  Args:
    parent: the parent resource name (an organizations/, folders/ or
      projects/ resource).
    filter: optional state filter string. A falsy value falls back to
      'PENDING'. NOTE: the parameter name shadows the builtin `filter` but is
      kept for backward compatibility with keyword callers.

  Returns:
    A generator of ApprovalRequest messages.
  """
  api_client = apis.GetClientInstance('accessapproval', 'v1')
  messages = apis.GetMessagesModule('accessapproval', 'v1')
  if 'organizations/' in parent:
    service = api_client.organizations_approvalRequests
    request = messages.AccessapprovalOrganizationsApprovalRequestsListRequest(
        parent=parent)
  elif 'folders/' in parent:
    service = api_client.folders_approvalRequests
    request = messages.AccessapprovalFoldersApprovalRequestsListRequest(
        parent=parent)
  else:
    service = api_client.projects_approvalRequests
    request = messages.AccessapprovalProjectsApprovalRequestsListRequest(
        parent=parent)
  # Default to showing only pending requests when no filter was supplied.
  request.filter = filter if filter else 'PENDING'
  return list_pager.YieldFromList(
      service, request, field='approvalRequests',
      batch_size_attribute='pageSize')

View File

@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Access approval service account API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.util import apis
def Get(name):
  """Fetches the access approval service account for a resource.

  Args:
    name: the resource name for which to look up the service account. The
      parent collection (organization, folder, or project) is inferred from
      the name.

  Returns:
    The service account response message for the resource.
  """
  api_client = apis.GetClientInstance('accessapproval', 'v1')
  messages = apis.GetMessagesModule('accessapproval', 'v1')
  if 'organizations/' in name:
    return api_client.organizations.GetServiceAccount(
        messages.AccessapprovalOrganizationsGetServiceAccountRequest(
            name=name))
  if 'folders/' in name:
    return api_client.folders.GetServiceAccount(
        messages.AccessapprovalFoldersGetServiceAccountRequest(name=name))
  return api_client.projects.GetServiceAccount(
      messages.AccessapprovalProjectsGetServiceAccountRequest(name=name))

View File

@@ -0,0 +1,136 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Access approval settings API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.util import apis
def Delete(name):
  """Deletes the access approval settings for a resource.

  Args:
    name: the settings resource name
      (e.g. projects/my-project/accessApprovalSettings). The parent collection
      (organization, folder, or project) is inferred from the name.

  Returns:
    The response from the delete call.
  """
  api_client = apis.GetClientInstance('accessapproval', 'v1')
  messages = apis.GetMessagesModule('accessapproval', 'v1')
  if 'organizations/' in name:
    return api_client.organizations.DeleteAccessApprovalSettings(
        messages.AccessapprovalOrganizationsDeleteAccessApprovalSettingsRequest(
            name=name))
  if 'folders/' in name:
    return api_client.folders.DeleteAccessApprovalSettings(
        messages.AccessapprovalFoldersDeleteAccessApprovalSettingsRequest(
            name=name))
  return api_client.projects.DeleteAccessApprovalSettings(
      messages.AccessapprovalProjectsDeleteAccessApprovalSettingsRequest(
          name=name))
def Get(name):
  """Fetches the access approval settings for a resource.

  Args:
    name: the settings resource name
      (e.g. projects/my-project/accessApprovalSettings). The parent collection
      (organization, folder, or project) is inferred from the name.

  Returns:
    The AccessApprovalSettings for the resource.
  """
  api_client = apis.GetClientInstance('accessapproval', 'v1')
  messages = apis.GetMessagesModule('accessapproval', 'v1')
  if 'organizations/' in name:
    return api_client.organizations.GetAccessApprovalSettings(
        messages.AccessapprovalOrganizationsGetAccessApprovalSettingsRequest(
            name=name))
  if 'folders/' in name:
    return api_client.folders.GetAccessApprovalSettings(
        messages.AccessapprovalFoldersGetAccessApprovalSettingsRequest(
            name=name))
  return api_client.projects.GetAccessApprovalSettings(
      messages.AccessapprovalProjectsGetAccessApprovalSettingsRequest(
          name=name))
def Update(
    name,
    notification_emails,
    enrolled_services,
    active_key_version,
    preferred_request_expiration_days,
    prefer_no_broad_approval_requests,
    notification_pubsub_topic,
    request_scope_max_width_preference,
    require_customer_visible_justification,
    approval_policy,
    update_mask,
):
  """Update the access approval settings for a resource.

  Args:
    name: the settings resource name (e.g. projects/123/accessApprovalSettings)
    notification_emails: list of email addresses
    enrolled_services: list of services
    active_key_version: KMS signing key version resource name
    preferred_request_expiration_days: the default expiration time for approval
      requests
    prefer_no_broad_approval_requests: communicates the preference to Google
      personnel to request access with as targeted a resource scope as possible
    notification_pubsub_topic: A pubsub topic to which notifications relating to
      approval requests should be sent
    request_scope_max_width_preference: specifies broadest scope of access for
      access requests without a specific method
    require_customer_visible_justification: to configure if a customer visible
      justification (i.e. Vector Case) is required for a Googler to create an
      Access Ticket to send to the customer when attempting to access customer
      resources.
    approval_policy: the policy for approving requests
    update_mask: which fields to update

  Returns:
    updated settings
  """
  client = apis.GetClientInstance('accessapproval', 'v1')
  msgs = apis.GetMessagesModule('accessapproval', 'v1')
  # Each enrolled service is identified by its cloud product name.
  services_protos = [
      msgs.EnrolledService(cloudProduct=s) for s in enrolled_services
  ]
  # Bug fix: removed a dead `settings = None` assignment that was immediately
  # overwritten by the constructor call below.
  settings = msgs.AccessApprovalSettings(
      name=name,
      enrolledServices=services_protos,
      notificationEmails=notification_emails,
      activeKeyVersion=active_key_version,
      preferredRequestExpirationDays=preferred_request_expiration_days,
      preferNoBroadApprovalRequests=prefer_no_broad_approval_requests,
      notificationPubsubTopic=notification_pubsub_topic,
      requestScopeMaxWidthPreference=request_scope_max_width_preference,
      requireCustomerVisibleJustification=require_customer_visible_justification,
      approvalPolicy=approval_policy,
  )
  # Dispatch to the service matching the parent collection in the name.
  if 'organizations/' in name:
    req = msgs.AccessapprovalOrganizationsUpdateAccessApprovalSettingsRequest(
        name=name, accessApprovalSettings=settings, updateMask=update_mask
    )
    return client.organizations.UpdateAccessApprovalSettings(req)
  if 'folders/' in name:
    req = msgs.AccessapprovalFoldersUpdateAccessApprovalSettingsRequest(
        name=name, accessApprovalSettings=settings, updateMask=update_mask
    )
    return client.folders.UpdateAccessApprovalSettings(req)
  req = msgs.AccessapprovalProjectsUpdateAccessApprovalSettingsRequest(
      name=name, accessApprovalSettings=settings, updateMask=update_mask
  )
  return client.projects.UpdateAccessApprovalSettings(req)

View File

@@ -0,0 +1,206 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unified diff resource printer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import difflib
import io
import re
from googlecloudsdk.core import exceptions
from googlecloudsdk.core.resource import resource_printer_base
from googlecloudsdk.core.resource import resource_projection_spec
from googlecloudsdk.core.resource import resource_projector
from googlecloudsdk.core.resource import resource_transform
from googlecloudsdk.core.resource import yaml_printer
class ACMDiffPrinter(resource_printer_base.ResourcePrinter):
  """A printer for an ndiff of the first two projection columns.

  A unified diff of the first two projection columns.

  Printer attributes:
    format: The format of the diffed resources. Each resource is converted
      to this format and the diff of the converted resources is displayed.
      The default is 'yaml'.
  """

  def __init__(self, *args, **kwargs):
    super(ACMDiffPrinter, self).__init__(
        *args, by_columns=True, non_empty_projection_required=True, **kwargs)
    # Format each resource is rendered in before diffing; defaults to yaml.
    self._print_format = self.attributes.get('format', 'yaml')

  def _Diff(self, old, new):
    """Prints a modified ndiff of formatter output for old and new.

    Example output:

    IngressPolicies:
    ingressFrom:
    sources:
    accessLevel: accessPolicies/123456789/accessLevels/my_level
    -resource: projects/123456789012
    +resource: projects/234567890123
    EgressPolicies:
    +egressTo:
    +operations:
    +actions:
    +action: method_for_all
    +actionType: METHOD
    +serviceName: chemisttest.googleapis.com
    +resources:
    +projects/345678901234

    Args:
      old: The old original resource.
      new: The new changed resource.
    """
    # Fill a buffer with the object as rendered originally.
    buf_old = io.StringIO()
    printer = self.Printer(self._print_format, out=buf_old)
    printer.PrintSingleRecord(old)
    # Fill a buffer with the object as rendered after the change.
    buf_new = io.StringIO()
    printer = self.Printer(self._print_format, out=buf_new)
    printer.PrintSingleRecord(new)
    lines_old = ''
    lines_new = ''
    # Send these two buffers to the ndiff() function for printing.
    # A None resource keeps the corresponding side empty, so the whole other
    # side shows up as additions (or deletions) in the diff.
    if old is not None:
      lines_old = self._FormatYamlPrinterLinesForDryRunDescribe(
          buf_old.getvalue().split('\n'))
    if new is not None:
      lines_new = self._FormatYamlPrinterLinesForDryRunDescribe(
          buf_new.getvalue().split('\n'))
    lines_diff = difflib.ndiff(lines_old, lines_new)
    empty_line_pattern = re.compile(r'^\s*$')
    empty_config_pattern = re.compile(r'^(\+|-)\s+\{\}$')
    for line in lines_diff:
      # We want to show the entire contents of resource, but without the
      # additional information added by ndiff, which always leads with '?'. We
      # also don't want to show empty lines produced from comparing unset
      # fields, as well as lines produced from comparing empty messages, which
      # will look like '+ {}' or '- {}'.
      if line and line[0] != '?' and not empty_line_pattern.match(
          line) and not empty_config_pattern.match(line):
        print(line)

  def _AddRecord(self, record, delimit=False):
    """Immediately prints the first two columns of record as a unified diff.

    Records with less than 2 columns are silently ignored.

    Args:
      record: A JSON-serializable object.
      delimit: Prints resource delimiters if True.
    """
    title = self.attributes.get('title')
    if title:
      self._out.Print(title)
      # NOTE(review): clears self._title after printing, though the title was
      # read from self.attributes above — confirm this attribute is used.
      self._title = None
    if len(record) > 1:
      self._Diff(record[0], record[1])

  def _FormatYamlPrinterLinesForDryRunDescribe(self, lines):
    """Tweak yaml printer formatted resources for ACM's dry run describe output.

    Replaces only the FIRST '-' on each line (the yaml list-item marker) with
    a space, so ndiff's +/- markers are not confused with yaml syntax.

    Args:
      lines: yaml printer formatted strings

    Returns:
      lines with no '-' prefix for yaml array elements.
    """
    return [line.replace('-', ' ', 1) for line in lines]
class Error(exceptions.Error):
  """Base exception for errors raised by the acm_printer module."""
class UnknownFormatError(Error):
  """Raised when an unsupported printer format name is requested."""
# Registry mapping printer format names to their implementing classes.
# 'default' and 'yaml' share the stock YAML printer; 'diff' renders a unified
# diff of the first two projection columns.
_FORMATTERS = {
    'default': yaml_printer.YamlPrinter,
    'diff': ACMDiffPrinter,
    'yaml': yaml_printer.YamlPrinter,
}
def Print(resources, print_format, out=None, defaults=None, single=False):
  """Prints the given resources using the requested format.

  Args:
    resources: A singleton or list of JSON-serializable Python objects.
    print_format: The _FORMATTER name with optional projection expression.
    out: Output stream, log.out if None.
    defaults: Optional resource_projection_spec.ProjectionSpec defaults.
    single: If True then resources is a single item and not a list. For
      example, use this to print a single object as JSON.
  """
  resource_printer = Printer(print_format, out=out, defaults=defaults)
  if not resource_printer:
    # A falsy printer means printing is disabled; nothing is consumed.
    return
  resource_printer.Print(resources, single)
def Printer(print_format, out=None, defaults=None, console_attr=None):
  """Constructs a resource printer from a format string.

  Args:
    print_format: The _FORMATTERS name with optional attributes and projection.
    out: Output stream, log.out if None.
    defaults: Optional resource_projection_spec.ProjectionSpec defaults.
    console_attr: The console attributes for the output stream. Ignored by some
      printers. If None then printers that require it will initialize it to
      match out.

  Raises:
    UnknownFormatError: The print_format is invalid.

  Returns:
    An initialized ResourcePrinter class or None if printing is disabled.
  """
  projection_defaults = resource_projection_spec.ProjectionSpec(
      defaults=defaults, symbols=resource_transform.GetTransforms())
  projector = resource_projector.Compile(
      expression=print_format, defaults=projection_defaults)
  printer_name = projector.Projection().Name()
  if not printer_name:
    # An empty name disables printing; do not consume resources.
    return None
  try:
    printer_class = _FORMATTERS[printer_name]
  except KeyError:
    raise UnknownFormatError("""\
Format for acm_printer must be one of {0}; received [{1}].
""".format(', '.join(SupportedFormats()), printer_name))
  return printer_class(
      out=out,
      name=printer_name,
      printer=Printer,
      projector=projector,
      console_attr=console_attr)
def SupportedFormats():
  """Returns the supported printer format names in sorted order."""
  return sorted(_FORMATTERS.keys())

View File

@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API library for Authorized Orgs Desc."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.accesscontextmanager import util
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import log
from googlecloudsdk.core import resources as core_resources
def _SetIfNotNone(field_name, field_value, obj, update_mask):
  """Conditionally sets a field and records it in the update mask.

  Args:
    field_name: The name of the field to set the value of.
    field_value: The value to set the field to. If it is None, the field will
      NOT be set.
    obj: The object on which the value is to be set.
    update_mask: The update mask to add this field to.

  Returns:
    True if the field was set and False otherwise.
  """
  # Guard clause: a None value means "no update" for this field.
  if field_value is None:
    return False
  setattr(obj, field_name, field_value)
  update_mask.append(field_name)
  return True
class Client(object):
  """High-level API client for Authorized Orgs."""

  def __init__(self, client=None, messages=None, version='v1'):
    self.client = util.GetClient(version=version) if not client else client
    self.messages = (
        self.client.MESSAGES_MODULE if not messages else messages)

  def Get(self, authorized_orgs_desc_ref):
    """Fetches a single AuthorizedOrgsDesc by resource reference."""
    request = (
        self.messages
        .AccesscontextmanagerAccessPoliciesAuthorizedOrgsDescsGetRequest(
            name=authorized_orgs_desc_ref.RelativeName()))
    return self.client.accessPolicies_authorizedOrgsDescs.Get(request)

  def List(self, policy_ref, limit=None):
    """Yields the AuthorizedOrgsDescs under the given access policy."""
    request = (
        self.messages
        .AccesscontextmanagerAccessPoliciesAuthorizedOrgsDescsListRequest(
            parent=policy_ref.RelativeName()))
    return list_pager.YieldFromList(
        self.client.accessPolicies_authorizedOrgsDescs,
        request,
        limit=limit,
        batch_size=None,
        batch_size_attribute='pageSize',
        field='authorizedOrgsDescs')

  def _ApplyPatch(self, authorized_orgs_desc_ref, authorized_orgs_desc,
                  update_mask):
    """Sends the PATCH request and waits for the resulting operation."""
    request = (
        self.messages
        .AccesscontextmanagerAccessPoliciesAuthorizedOrgsDescsPatchRequest(
            authorizedOrgsDesc=authorized_orgs_desc,
            name=authorized_orgs_desc_ref.RelativeName(),
            updateMask=','.join(update_mask)))
    operation = self.client.accessPolicies_authorizedOrgsDescs.Patch(request)
    poller = util.OperationPoller(
        self.client.accessPolicies_authorizedOrgsDescs, self.client.operations,
        authorized_orgs_desc_ref)
    operation_ref = core_resources.REGISTRY.Parse(
        operation.name, collection='accesscontextmanager.operations')
    return waiter.WaitFor(
        poller, operation_ref,
        'Waiting for PATCH operation [{}]'.format(operation_ref.Name()))

  def Patch(self, authorized_orgs_desc_ref, orgs=None):
    """Patch an authorized orgs desc.

    Args:
      authorized_orgs_desc_ref: AuthorizedOrgsDesc, reference to the
        authorizedOrgsDesc to patch
      orgs: list of str, the names of orgs ( 'organizations/...') or None if
        not updating.

    Returns:
      AuthorizedOrgsDesc, the updated Authorized Orgs Desc.
    """
    desc = self.messages.AuthorizedOrgsDesc()
    update_mask = []
    _SetIfNotNone('orgs', orgs, desc, update_mask)
    if not update_mask:
      # No update mask implies no fields were actually edited; skip the RPC.
      log.warning(
          'The update specified results in an identical resource. Skipping request.'
      )
      return desc
    return self._ApplyPatch(authorized_orgs_desc_ref, desc, update_mask)

View File

@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API library for access context manager cloud-bindings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.accesscontextmanager import util
class Client(object):
  """Client for Access Context Manager Access cloud-bindings service."""

  def __init__(self, client=None, messages=None, version=None):
    # Fall back to the shared ACM client/messages module when not injected.
    self.client = util.GetClient(version=version) if not client else client
    self.messages = (
        self.client.MESSAGES_MODULE if not messages else messages)

View File

@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*- #
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API library for access context manager levels."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.accesscontextmanager import util
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import resources
class Client(object):
  """High-level API client for access context manager access levels."""

  def __init__(self, client=None, messages=None, version=None):
    # Fall back to the shared ACM client/messages module when not injected.
    self.client = client or util.GetClient(version=version)
    self.messages = messages or self.client.MESSAGES_MODULE

  def List(self, policy_ref, limit=None):
    """Yields the access levels under the given access policy.

    Args:
      policy_ref: resources.Resource, the access policy to list under.
      limit: int or None, the maximum number of levels to yield.

    Returns:
      A generator of AccessLevel messages.
    """
    req = (
        self.messages.AccesscontextmanagerAccessPoliciesAccessLevelsListRequest(
            parent=policy_ref.RelativeName()
        )
    )
    return list_pager.YieldFromList(
        self.client.accessPolicies_accessLevels,
        req,
        limit=limit,
        batch_size_attribute='pageSize',
        batch_size=None,
        field='accessLevels',
    )

  def Patch(
      self,
      level_ref,
      description=None,
      title=None,
      basic_level_combine_function=None,
      basic_level_conditions=None,
      custom_level_expr=None,
  ):
    """Patch an access level.

    Args:
      level_ref: resources.Resource, reference to the level to patch
      description: str, description of the level or None if not updating
      title: str, title of the level or None if not updating
      basic_level_combine_function: ZoneTypeValueValuesEnum, combine function
        enum value of the level or None if not updating
      basic_level_conditions: list of Condition, the conditions for a basic
        level or None if not updating
      custom_level_expr: the expression of the Custom level, or none if not
        updating.

    Returns:
      AccessLevel, the updated access level
    """
    level = self.messages.AccessLevel()
    # Only fields named in update_mask are modified server-side; each branch
    # below both sets the field and records it in the mask.
    update_mask = []
    if description is not None:
      update_mask.append('description')
      level.description = description
    if title is not None:
      update_mask.append('title')
      level.title = title
    if basic_level_combine_function is not None:
      update_mask.append('basic.combiningFunction')
      # Create the basic sub-message lazily so unset basic fields stay unset.
      level.basic = level.basic or self.messages.BasicLevel()
      level.basic.combiningFunction = basic_level_combine_function
    if basic_level_conditions is not None:
      update_mask.append('basic.conditions')
      level.basic = level.basic or self.messages.BasicLevel()
      level.basic.conditions = basic_level_conditions
    if custom_level_expr is not None:
      update_mask.append('custom')
      level.custom = level.custom or self.messages.CustomLevel()
      level.custom.expr = custom_level_expr
    update_mask.sort()  # For ease-of-testing
    m = self.messages
    request_type = m.AccesscontextmanagerAccessPoliciesAccessLevelsPatchRequest
    request = request_type(
        accessLevel=level,
        name=level_ref.RelativeName(),
        updateMask=','.join(update_mask),
    )
    # The PATCH is asynchronous: poll the returned long-running operation
    # until it completes, then return the final resource.
    operation = self.client.accessPolicies_accessLevels.Patch(request)
    poller = util.OperationPoller(self.client.accessPolicies_accessLevels,
                                  self.client.operations, level_ref)
    operation_ref = resources.REGISTRY.Parse(
        operation.name, collection='accesscontextmanager.operations')
    return waiter.WaitFor(
        poller, operation_ref,
        'Waiting for PATCH operation [{}]'.format(operation_ref.Name()))

View File

@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*- #
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API library for access context manager policies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.accesscontextmanager import util
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import resources
class Client(object):
  """Client for Access Context Manager Access Policies service."""

  def __init__(self, client=None, messages=None, version=None):
    # Fall back to the shared ACM client/messages module when not injected.
    self.client = client or util.GetClient(version=version)
    self.messages = messages or self.client.MESSAGES_MODULE

  def List(self, organization_ref, limit=None):
    """Yields the access policies under the given organization.

    Args:
      organization_ref: resources.Resource, the organization to list under.
      limit: int or None, the maximum number of policies to yield.

    Returns:
      A generator of AccessPolicy messages.
    """
    req = self.messages.AccesscontextmanagerAccessPoliciesListRequest(
        parent=organization_ref.RelativeName())
    return list_pager.YieldFromList(
        self.client.accessPolicies, req,
        limit=limit,
        batch_size_attribute='pageSize',
        batch_size=None,
        field='accessPolicies')

  def Patch(self, policy_ref, title=None):
    """Patch an access policy.

    Args:
      policy_ref: resources.Resource, reference to the policy to patch
      title: str, title of the policy or None if not updating

    Returns:
      AccessPolicy, the updated access policy
    """
    policy = self.messages.AccessPolicy()
    update_mask = []
    if title is not None:
      update_mask.append('title')
      policy.title = title
    update_mask.sort()  # For ease-of-testing
    m = self.messages
    request_type = m.AccesscontextmanagerAccessPoliciesPatchRequest
    request = request_type(
        accessPolicy=policy,
        name=policy_ref.RelativeName(),
        updateMask=','.join(update_mask),
    )
    operation = self.client.accessPolicies.Patch(request)
    # Bug fix: removed a dead waiter.CloudOperationPoller that was created and
    # immediately overwritten; only the util.OperationPoller is used.
    poller = util.OperationPoller(
        self.client.accessPolicies, self.client.operations, policy_ref)
    operation_ref = resources.REGISTRY.Parse(
        operation.name, collection='accesscontextmanager.operations')
    return waiter.WaitFor(
        poller, operation_ref,
        'Waiting for PATCH operation [{}]'.format(operation_ref.Name()))

View File

@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API library for Supported Permissions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.accesscontextmanager import util
class Client(object):
  """High-level API client for Supported Permissions."""

  def __init__(self, client=None, messages=None, version=None):
    self.client = util.GetClient(version=version) if not client else client
    self.messages = (
        self.client.MESSAGES_MODULE if not messages else messages)

  def List(self, page_size=100, limit=None):
    """Make API call to list VPC Service Controls supported permissions.

    Args:
      page_size: The page size to list.
      limit: The maximum number of permissions to display.

    Returns:
      The list of VPC Service Controls supported permissions.
    """
    request = self.messages.AccesscontextmanagerPermissionsListRequest()
    return list_pager.YieldFromList(
        self.client.permissions,
        request,
        limit=limit,
        batch_size=page_size,
        batch_size_attribute='pageSize',
        field='supportedPermissions',
    )

View File

@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API library for Supported Services."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.accesscontextmanager import util
class Client(object):
  """High-level API client for Supported Services."""

  def __init__(self, client=None, messages=None, version='v1'):
    self.client = util.GetClient(version=version) if not client else client
    self.messages = (
        self.client.MESSAGES_MODULE if not messages else messages)

  def Get(self, supported_services_ref):
    """Fetches a single supported service by resource reference."""
    request = self.messages.AccesscontextmanagerServicesGetRequest(
        name=supported_services_ref.RelativeName())
    return self.client.services.Get(request)

  def List(self, page_size=200, limit=None):
    """Make API call to list VPC Service Controls supported services.

    Args:
      page_size: The page size to list.
      limit: The maximum number of services to display.

    Returns:
      The list of VPC Service Controls supported services
    """
    request = self.messages.AccesscontextmanagerServicesListRequest()
    return list_pager.YieldFromList(
        self.client.services,
        request,
        limit=limit,
        batch_size=page_size,
        batch_size_attribute='pageSize',
        field='supportedServices',
    )

View File

@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*- #
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API utilities for access context manager."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import waiter
# The API name used for all accesscontextmanager client/message lookups below.
_API_NAME = 'accesscontextmanager'


def _GetDefaultVersion():
  """Returns the default accesscontextmanager API version from the registry."""
  return apis.ResolveVersion(_API_NAME)
def GetMessages(version=None):
  """Returns the accesscontextmanager messages module.

  Args:
    version: str or None, the API version; falls back to the default
      resolved version when falsy.
  """
  return apis.GetMessagesModule(_API_NAME, version or _GetDefaultVersion())
def GetClient(version=None):
  """Returns an accesscontextmanager client instance.

  Args:
    version: str or None, the API version; falls back to the default
      resolved version when falsy.
  """
  return apis.GetClientInstance(_API_NAME, version or _GetDefaultVersion())
class OperationPoller(waiter.CloudOperationPoller):
  """Poller that resolves a finished operation to the target resource."""

  def __init__(self, result_service, operation_service, resource_ref):
    super(OperationPoller, self).__init__(result_service, operation_service)
    # Kept so GetResult can re-fetch the resource by its relative name.
    self.resource_ref = resource_ref

  def GetResult(self, operation):
    """Fetches the resource itself instead of decoding the operation."""
    del operation  # Unused; the resource is looked up directly by name.
    get_request_type = self.result_service.GetRequestType('Get')
    get_request = get_request_type(name=self.resource_ref.RelativeName())
    return self.result_service.Get(get_request)

View File

@@ -0,0 +1,353 @@
# -*- coding: utf-8 -*- #
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API library for VPC Service Controls Service Perimeters."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.accesscontextmanager import util
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import log
from googlecloudsdk.core import resources as core_resources
import six
def _SetIfNotNone(field_name, field_value, obj, update_mask):
"""Sets specified field to the provided value and adds it to update mask.
Args:
field_name: The name of the field to set the value of.
field_value: The value to set the field to. If it is None, the field will
NOT be set.
obj: The object on which the value is to be set.
update_mask: The update mask to add this field to.
Returns:
True if the field was set and False otherwise.
"""
if field_value is not None:
setattr(obj, field_name, field_value)
update_mask.append(field_name)
return True
return False
def _CreateServicePerimeterConfig(messages,
                                  mask_prefix,
                                  resources,
                                  restricted_services,
                                  levels,
                                  vpc_allowed_services,
                                  enable_vpc_accessible_services,
                                  vpc_yaml_flag_used,
                                  vpc_accessible_services_config=None,
                                  ingress_policies=None,
                                  egress_policies=None):
  """Builds a ServicePerimeterConfig and the matching update-mask entries.

  Args:
    messages: The accesscontextmanager API messages module.
    mask_prefix: str, prefix (e.g. 'status' or 'spec') prepended to every
      returned mask entry.
    resources: list of str or None, resource names inside the perimeter.
    restricted_services: list of str or None, restricted service names.
    levels: list or None, access levels given either as plain strings or as
      resource references.
    vpc_allowed_services: list of str or None, services callable within the
      perimeter.
    enable_vpc_accessible_services: bool or None, whether the allowed-services
      restriction is enforced.
    vpc_yaml_flag_used: bool, True when the VPC config came from a YAML flag;
      in that case vpc_accessible_services_config is used verbatim.
    vpc_accessible_services_config: VpcAccessibleServices message or None.
    ingress_policies: list of IngressPolicy or None.
    egress_policies: list of EgressPolicy or None.

  Returns:
    A (ServicePerimeterConfig, update_mask) pair; (None, []) when no field
    was provided.
  """
  config = messages.ServicePerimeterConfig()
  mask = []
  # Plain list-valued fields: only fields explicitly provided are set.
  for field, value in (('resources', resources),
                       ('restrictedServices', restricted_services),
                       ('ingressPolicies', ingress_policies),
                       ('egressPolicies', egress_policies)):
    _SetIfNotNone(field, value, config, mask)
  if levels is not None:
    mask.append('accessLevels')
    # Accept either ready-made level name strings or Access Level resource
    # references, from which the relative name is extracted.
    config.accessLevels = [
        level if isinstance(level, six.string_types) else level.RelativeName()
        for level in levels
    ]
  if vpc_yaml_flag_used:
    # The YAML flag supplies a fully-formed VpcAccessibleServices message.
    mask.append('vpcAccessibleServices')
    config.vpcAccessibleServices = vpc_accessible_services_config
  elif (enable_vpc_accessible_services is not None
        or vpc_allowed_services is not None):
    service_filter = messages.VpcAccessibleServices()
    sub_mask = []
    _SetIfNotNone('allowedServices', vpc_allowed_services, service_filter,
                  sub_mask)
    _SetIfNotNone('enableRestriction', enable_vpc_accessible_services,
                  service_filter, sub_mask)
    config.vpcAccessibleServices = service_filter
    mask.extend('vpcAccessibleServices.' + item for item in sub_mask)
  if not mask:
    return None, []
  return config, ['{}.{}'.format(mask_prefix, entry) for entry in mask]
class Client(object):
  """High-level API client for VPC Service Controls Service Perimeters."""

  def __init__(self, client=None, messages=None, version='v1'):
    # An injected client/messages pair (e.g. from tests) takes precedence
    # over the default accesscontextmanager client for the given version.
    self.client = client or util.GetClient(version=version)
    self.messages = messages or self.client.MESSAGES_MODULE

  def Get(self, zone_ref):
    """Returns the ServicePerimeter named by the given perimeter reference."""
    return self.client.accessPolicies_servicePerimeters.Get(
        self.messages
        .AccesscontextmanagerAccessPoliciesServicePerimetersGetRequest(
            name=zone_ref.RelativeName()))

  def List(self, policy_ref, limit=None):
    """Yields the Service Perimeters under the given access policy."""
    req = self.messages.AccesscontextmanagerAccessPoliciesServicePerimetersListRequest(
        parent=policy_ref.RelativeName())
    return list_pager.YieldFromList(
        self.client.accessPolicies_servicePerimeters,
        req,
        limit=limit,
        batch_size_attribute='pageSize',
        batch_size=None,  # None lets the server choose the page size.
        field='servicePerimeters')

  def Commit(self, policy_ref, etag):
    """Commits the dry-run configs of all perimeters under a policy.

    Args:
      policy_ref: resources.Resource, reference to the parent access policy.
      etag: str, the optional etag of the policy version to commit.

    Returns:
      The result of waiting on the long-running COMMIT operation.
    """
    commit_req = self.messages.CommitServicePerimetersRequest(etag=etag)
    req = self.messages.AccesscontextmanagerAccessPoliciesServicePerimetersCommitRequest(
        parent=policy_ref.RelativeName(),
        commitServicePerimetersRequest=commit_req)
    operation = self.client.accessPolicies_servicePerimeters.Commit(req)
    # Commit has no resource result, so poll with the no-resource poller.
    poller = waiter.CloudOperationPollerNoResources(self.client.operations)
    operation_ref = core_resources.REGISTRY.Parse(
        operation.name, collection='accesscontextmanager.operations')
    return waiter.WaitFor(
        poller, operation_ref,
        'Waiting for COMMIT operation [{}]'.format(operation_ref.Name()))

  def _ApplyPatch(self, perimeter_ref, perimeter, update_mask):
    """Applies a PATCH to the provided Service Perimeter."""
    m = self.messages
    update_mask = sorted(update_mask)  # For ease-of-testing
    request_type = (
        m.AccesscontextmanagerAccessPoliciesServicePerimetersPatchRequest)
    request = request_type(
        servicePerimeter=perimeter,
        name=perimeter_ref.RelativeName(),
        updateMask=','.join(update_mask),
    )
    operation = self.client.accessPolicies_servicePerimeters.Patch(request)
    # Poll until done, then return the updated perimeter resource itself.
    poller = util.OperationPoller(self.client.accessPolicies_servicePerimeters,
                                  self.client.operations, perimeter_ref)
    operation_ref = core_resources.REGISTRY.Parse(
        operation.name, collection='accesscontextmanager.operations')
    return waiter.WaitFor(
        poller, operation_ref,
        'Waiting for PATCH operation [{}]'.format(operation_ref.Name()))

  def Patch(
      self,
      perimeter_ref,
      description=None,
      title=None,
      perimeter_type=None,
      resources=None,
      restricted_services=None,
      levels=None,
      vpc_allowed_services=None,
      enable_vpc_accessible_services=None,
      vpc_yaml_flag_used=False,
      vpc_accessible_services_config=None,
      ingress_policies=None,
      egress_policies=None,
      etag=None,
  ):
    """Patch a service perimeter.

    Args:
      perimeter_ref: resources.Resource, reference to the perimeter to patch
      description: str, description of the zone or None if not updating
      title: str, title of the zone or None if not updating
      perimeter_type: PerimeterTypeValueValuesEnum type enum value for the
        level or None if not updating
      resources: list of str, the names of resources (for now, just
        'projects/...') in the zone or None if not updating.
      restricted_services: list of str, the names of services
        ('example.googleapis.com') that *are* restricted by the access zone
        or None if not updating.
      levels: list of Resource, the access levels (in the same policy) that
        must be satisfied for calls into this zone or None if not updating.
      vpc_allowed_services: list of str, the names of services
        ('example.googleapis.com') that *are* allowed to be made within the
        access zone, or None if not updating.
      enable_vpc_accessible_services: bool, whether to restrict the set of
        APIs callable within the access zone, or None if not updating.
      vpc_yaml_flag_used: bool, whether the vpc yaml flag was used.
      vpc_accessible_services_config: VpcAccessibleServices, or None if not
        updating.
      ingress_policies: list of IngressPolicy, or None if not updating.
      egress_policies: list of EgressPolicy, or None if not updating.
      etag: str, the optional etag for the version of the Perimeter that
        this operation is to be performed on.

    Returns:
      ServicePerimeter, the updated Service Perimeter.
    """
    m = self.messages
    perimeter = m.ServicePerimeter()
    update_mask = []
    _SetIfNotNone('title', title, perimeter, update_mask)
    _SetIfNotNone('description', description, perimeter, update_mask)
    _SetIfNotNone('perimeterType', perimeter_type, perimeter, update_mask)
    _SetIfNotNone('etag', etag, perimeter, update_mask)
    # Patch targets the enforced config, so mask entries are 'status.*'.
    config, config_mask_additions = _CreateServicePerimeterConfig(
        messages=m,
        mask_prefix='status',
        resources=resources,
        restricted_services=restricted_services,
        levels=levels,
        vpc_allowed_services=vpc_allowed_services,
        enable_vpc_accessible_services=enable_vpc_accessible_services,
        vpc_yaml_flag_used=vpc_yaml_flag_used,
        vpc_accessible_services_config=vpc_accessible_services_config,
        ingress_policies=ingress_policies,
        egress_policies=egress_policies)
    perimeter.status = config
    update_mask.extend(config_mask_additions)
    # No update mask implies no fields were actually edited, so this is a no-op.
    if not update_mask:
      log.warning(
          'The update specified results in an identical resource. Skipping request.'
      )
      return perimeter
    return self._ApplyPatch(perimeter_ref, perimeter, update_mask)

  def PatchDryRunConfig(
      self,
      perimeter_ref,
      description=None,
      title=None,
      perimeter_type=None,
      resources=None,
      restricted_services=None,
      levels=None,
      vpc_allowed_services=None,
      enable_vpc_accessible_services=None,
      vpc_yaml_flag_used=False,
      vpc_accessible_services_config=None,
      ingress_policies=None,
      egress_policies=None,
      etag=None,
  ):
    """Patch the dry-run config (spec) for a Service Perimeter.

    Args:
      perimeter_ref: resources.Resource, reference to the perimeter to patch
      description: str, description of the zone or None if not updating
      title: str, title of the zone or None if not updating
      perimeter_type: PerimeterTypeValueValuesEnum type enum value for the
        level or None if not updating
      resources: list of str, the names of resources (for now, just
        'projects/...') in the zone or None if not updating.
      restricted_services: list of str, the names of services
        ('example.googleapis.com') that *are* restricted by the access zone
        or None if not updating.
      levels: list of Resource, the access levels (in the same policy) that
        must be satisfied for calls into this zone or None if not updating.
      vpc_allowed_services: list of str, the names of services
        ('example.googleapis.com') that *are* allowed to be made within the
        access zone, or None if not updating.
      enable_vpc_accessible_services: bool, whether to restrict the set of
        APIs callable within the access zone, or None if not updating.
      vpc_yaml_flag_used: bool, whether the vpc yaml flag was used.
      vpc_accessible_services_config: VpcAccessibleServices, or None if not
        updating.
      ingress_policies: list of IngressPolicy, or None if not updating.
      egress_policies: list of EgressPolicy, or None if not updating.
      etag: str, the optional etag for the version of the Perimeter that
        this operation is to be performed on.

    Returns:
      ServicePerimeter, the updated Service Perimeter.
    """
    m = self.messages
    perimeter = m.ServicePerimeter()
    update_mask = []
    if _SetIfNotNone('title', title, perimeter, update_mask):
      perimeter.name = perimeter_ref.RelativeName()  # Necessary for upsert.
      update_mask.append('name')
    _SetIfNotNone('description', description, perimeter, update_mask)
    _SetIfNotNone('perimeterType', perimeter_type, perimeter, update_mask)
    _SetIfNotNone('etag', etag, perimeter, update_mask)
    # Dry-run patches target the spec, so mask entries are 'spec.*'.
    config, config_mask_additions = _CreateServicePerimeterConfig(
        messages=m,
        mask_prefix='spec',
        resources=resources,
        restricted_services=restricted_services,
        levels=levels,
        vpc_allowed_services=vpc_allowed_services,
        enable_vpc_accessible_services=enable_vpc_accessible_services,
        vpc_yaml_flag_used=vpc_yaml_flag_used,
        vpc_accessible_services_config=vpc_accessible_services_config,
        ingress_policies=ingress_policies,
        egress_policies=egress_policies)
    perimeter.spec = config
    update_mask.extend(config_mask_additions)
    # Always mark the perimeter as having an explicit dry-run spec.
    perimeter.useExplicitDryRunSpec = True
    update_mask.append('useExplicitDryRunSpec')
    return self._ApplyPatch(perimeter_ref, perimeter, update_mask)

  def EnforceDryRunConfig(self, perimeter_ref):
    """Promotes a Service Perimeter's dry-run config to enforcement config.

    Args:
      perimeter_ref: resources.Resource, reference to the perimeter to patch

    Returns:
      ServicePerimeter, the updated Service Perimeter.
    """
    # Fetch the current perimeter so its spec can be copied into status.
    original_perimeter = self.Get(perimeter_ref)
    m = self.messages
    perimeter = m.ServicePerimeter()
    update_mask = ['status', 'spec', 'useExplicitDryRunSpec']
    # Promote: dry-run spec becomes the enforced status; spec is cleared.
    perimeter.status = original_perimeter.spec
    perimeter.spec = None
    perimeter.useExplicitDryRunSpec = False
    return self._ApplyPatch(perimeter_ref, perimeter, update_mask)

  def UnsetSpec(self, perimeter_ref, use_explicit_dry_run_spec):
    """Unsets the spec for a Service Perimeter.

    Args:
      perimeter_ref: resources.Resource, reference to the perimeter to patch.
      use_explicit_dry_run_spec: The value to use for the perimeter field of
        the same name.

    Returns:
      ServicePerimeter, the updated Service Perimeter.
    """
    perimeter = self.messages.ServicePerimeter()
    perimeter.useExplicitDryRunSpec = use_explicit_dry_run_spec
    perimeter.spec = None
    update_mask = ['spec', 'useExplicitDryRunSpec']
    return self._ApplyPatch(perimeter_ref, perimeter, update_mask)

View File

@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API utilities for `gcloud active-directory` commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.calliope import base
# Maps a gcloud release track to the managedidentities API version it uses.
API_VERSION_FOR_TRACK = {
    base.ReleaseTrack.BETA: 'v1beta1',
    base.ReleaseTrack.ALPHA: 'v1alpha1'
}
def Client(api_version):
  """Returns a client instance for the managedidentities API.

  Args:
    api_version: str, the managedidentities API version to use.
  """
  api_name = 'managedidentities'
  return apis.GetClientInstance(api_name, api_version)
def Messages(api_version):
  """Returns the messages module for the managedidentities API.

  Args:
    api_version: str, the managedidentities API version to use.
  """
  api_name = 'managedidentities'
  return apis.GetMessagesModule(api_name, api_version)

View File

@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper for user-visible error exceptions to raise in the CLI."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.core import exceptions
class ActiveDirectoryError(exceptions.Error):
  """Generic managedidentities error.

  User-visible error for `gcloud active-directory` commands; raising it (or a
  subclass) surfaces a clean CLI error message instead of a traceback.
  """

View File

@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for querying custom jobs in AI Platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.export import util as export_util
from googlecloudsdk.core.console import console_io
class CustomJobsClient(object):
  """Client used for interacting with CustomJob endpoint."""

  def __init__(self, version=constants.GA_VERSION):
    client = apis.GetClientInstance(constants.AI_PLATFORM_API_NAME,
                                    constants.AI_PLATFORM_API_VERSION[version])
    self._messages = client.MESSAGES_MODULE
    self._version = version
    self._service = client.projects_locations_customJobs
    # Message class names carry a version-specific prefix (e.g.
    # 'GoogleCloudAiplatformV1'), used by GetMessage for lookups.
    self._message_prefix = constants.AI_PLATFORM_MESSAGE_PREFIX[version]

  def GetMessage(self, message_name):
    """Returns the API message class by name, or None when not found."""
    return getattr(
        self._messages,
        '{prefix}{name}'.format(prefix=self._message_prefix,
                                name=message_name), None)

  def CustomJobMessage(self):
    """Returns the CustomJob resource message."""
    return self.GetMessage('CustomJob')

  def Create(self,
             parent,
             job_spec,
             display_name=None,
             kms_key_name=None,
             labels=None):
    """Constructs a request and sends it to the endpoint to create a custom job instance.

    Args:
      parent: str, The project resource path of the custom job to create.
      job_spec: The CustomJobSpec message instance for the job creation
        request.
      display_name: str, The display name of the custom job to create.
      kms_key_name: A customer-managed encryption key to use for the custom
        job.
      labels: LabelValues, map-like user-defined metadata to organize the
        custom jobs.

    Returns:
      A CustomJob message instance created.
    """
    custom_job = self.CustomJobMessage()(
        displayName=display_name, jobSpec=job_spec)
    if kms_key_name is not None:
      custom_job.encryptionSpec = self.GetMessage('EncryptionSpec')(
          kmsKeyName=kms_key_name)
    if labels:
      custom_job.labels = labels
    # The create request wraps the job under a version-specific field name.
    if self._version == constants.BETA_VERSION:
      return self._service.Create(
          self._messages.AiplatformProjectsLocationsCustomJobsCreateRequest(
              parent=parent, googleCloudAiplatformV1beta1CustomJob=custom_job))
    else:
      return self._service.Create(
          self._messages.AiplatformProjectsLocationsCustomJobsCreateRequest(
              parent=parent, googleCloudAiplatformV1CustomJob=custom_job))

  def List(self, limit=None, region=None):
    """Yields the custom jobs under the given region, up to `limit`."""
    return list_pager.YieldFromList(
        self._service,
        self._messages.AiplatformProjectsLocationsCustomJobsListRequest(
            parent=region),
        field='customJobs',
        batch_size_attribute='pageSize',
        limit=limit)

  def Get(self, name):
    """Returns the CustomJob with the given resource name."""
    request = self._messages.AiplatformProjectsLocationsCustomJobsGetRequest(
        name=name)
    return self._service.Get(request)

  def Cancel(self, name):
    """Requests cancellation of the CustomJob with the given resource name."""
    request = self._messages.AiplatformProjectsLocationsCustomJobsCancelRequest(
        name=name)
    return self._service.Cancel(request)

  def CheckJobComplete(self, name):
    """Returns a function to decide if log fetcher should continue polling.

    Args:
      name: String id of job.

    Returns:
      A one-argument function decides if log fetcher should continue.
    """
    request = self._messages.AiplatformProjectsLocationsCustomJobsGetRequest(
        name=name)
    response = self._service.Get(request)
    # NOTE(review): the job status is fetched once, here; the closure below
    # re-checks the *same* response snapshot on every call rather than
    # re-polling — confirm this is intended.

    def ShouldContinue(periods_without_logs):
      # Always continue through the first quiet period; afterwards stop once
      # the job has an end time recorded in the snapshot.
      if periods_without_logs <= 1:
        return True
      return response.endTime is None

    return ShouldContinue

  def ImportResourceMessage(self, yaml_file, message_name):
    """Import a messages class instance typed by name from a YAML file."""
    # '-' as yaml_file reads the resource definition from stdin.
    data = console_io.ReadFromFileOrStdin(yaml_file, binary=False)
    message_type = self.GetMessage(message_name)
    return export_util.Import(message_type=message_type, stream=data)

View File

@@ -0,0 +1,222 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with AI Platform deployment resource pools API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import flags
class DeploymentResourcePoolsClient(object):
  """High-level client for the AI Platform deployment resource pools surface."""

  def __init__(self, client=None, messages=None, version=None):
    # An injected client/messages pair (e.g. from tests) takes precedence
    # over the default AI Platform client for the requested version.
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version]
    )
    self.messages = messages or self.client.MESSAGES_MODULE

  def CreateBeta(
      self,
      location_ref,
      deployment_resource_pool_id,
      autoscaling_metric_specs=None,
      accelerator_dict=None,
      min_replica_count=None,
      max_replica_count=None,
      machine_type=None,
      tpu_topology=None,
      multihost_gpu_node_count=None,
      reservation_affinity=None,
      spot=False,
      required_replica_count=None,
  ):
    """Creates a new deployment resource pool using v1beta1 API.

    Args:
      location_ref: Resource, the parsed location to create a deployment
        resource pool.
      deployment_resource_pool_id: str, The ID to use for the
        DeploymentResourcePool, which will become the final component of the
        DeploymentResourcePool's resource name.
      autoscaling_metric_specs: dict or None, the metric specification that
        defines the target resource utilization for calculating the desired
        replica count.
      accelerator_dict: dict or None, the accelerator attached to the
        deployment resource pool from args.
      min_replica_count: int or None, The minimum number of machine replicas
        this deployment resource pool will be always deployed on. This value
        must be greater than or equal to 1.
      max_replica_count: int or None, The maximum number of replicas this
        deployment resource pool may be deployed on when the traffic against
        it increases.
      machine_type: str or None, Immutable. The type of the machine.
      tpu_topology: str or None, the topology of the TPU to serve the model.
      multihost_gpu_node_count: int or None, the number of nodes per replica
        for multihost GPU deployments.
      reservation_affinity: dict or None, the reservation affinity of the
        deployed model which specifies which reservations the deployed model
        can use.
      spot: bool, whether or not deploy the model on spot resources.
      required_replica_count: int or None, The required number of replicas
        this deployment resource pool will be considered successfully
        deployed. This value must be greater than or equal to 1 and less
        than or equal to min_replica_count.

    Returns:
      A long-running operation for Create.
    """
    # Assemble the machine spec from the individually-optional flags.
    machine_spec = self.messages.GoogleCloudAiplatformV1beta1MachineSpec()
    if machine_type is not None:
      machine_spec.machineType = machine_type
    if tpu_topology is not None:
      machine_spec.tpuTopology = tpu_topology
    if multihost_gpu_node_count is not None:
      machine_spec.multihostGpuNodeCount = multihost_gpu_node_count
    accelerator = flags.ParseAcceleratorFlag(
        accelerator_dict, constants.BETA_VERSION
    )
    if accelerator is not None:
      machine_spec.acceleratorType = accelerator.acceleratorType
      machine_spec.acceleratorCount = accelerator.acceleratorCount
    if reservation_affinity is not None:
      machine_spec.reservationAffinity = flags.ParseReservationAffinityFlag(
          reservation_affinity, constants.BETA_VERSION
      )
    dedicated = self.messages.GoogleCloudAiplatformV1beta1DedicatedResources(
        machineSpec=machine_spec, spot=spot
    )
    # The API requires at least one replica; default when unspecified.
    dedicated.minReplicaCount = min_replica_count or 1
    if max_replica_count is not None:
      dedicated.maxReplicaCount = max_replica_count
    if required_replica_count is not None:
      dedicated.requiredReplicaCount = required_replica_count
    if autoscaling_metric_specs is not None:
      autoscaling_metric_specs_list = []
      # Sorted iteration keeps the request contents deterministic.
      for name, target in sorted(autoscaling_metric_specs.items()):
        autoscaling_metric_specs_list.append(
            self.messages.GoogleCloudAiplatformV1beta1AutoscalingMetricSpec(
                metricName=constants.OP_AUTOSCALING_METRIC_NAME_MAPPER[name],
                target=target
            )
        )
      dedicated.autoscalingMetricSpecs = autoscaling_metric_specs_list
    pool = self.messages.GoogleCloudAiplatformV1beta1DeploymentResourcePool(
        dedicatedResources=dedicated
    )
    pool_request = self.messages.GoogleCloudAiplatformV1beta1CreateDeploymentResourcePoolRequest(
        deploymentResourcePool=pool,
        deploymentResourcePoolId=deployment_resource_pool_id
    )
    req = self.messages.AiplatformProjectsLocationsDeploymentResourcePoolsCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1beta1CreateDeploymentResourcePoolRequest=pool_request
    )
    operation = self.client.projects_locations_deploymentResourcePools.Create(
        req
    )
    return operation

  def DeleteBeta(self, deployment_resource_pool_ref):
    """Deletes a deployment resource pool using v1beta1 API.

    Args:
      deployment_resource_pool_ref: str, The deployment resource pool to
        delete.

    Returns:
      A GoogleProtobufEmpty response message for delete.
    """
    req = self.messages.AiplatformProjectsLocationsDeploymentResourcePoolsDeleteRequest(
        name=deployment_resource_pool_ref.RelativeName()
    )
    operation = self.client.projects_locations_deploymentResourcePools.Delete(
        req
    )
    return operation

  def DescribeBeta(self, deployment_resource_pool_ref):
    """Describes a deployment resource pool using v1beta1 API.

    Args:
      deployment_resource_pool_ref: str, Deployment resource pool to describe.

    Returns:
      GoogleCloudAiplatformV1beta1DeploymentResourcePool response message.
    """
    req = self.messages.AiplatformProjectsLocationsDeploymentResourcePoolsGetRequest(
        name=deployment_resource_pool_ref.RelativeName()
    )
    response = self.client.projects_locations_deploymentResourcePools.Get(req)
    return response

  def ListBeta(self, location_ref):
    """Lists deployment resource pools using v1beta1 API.

    Args:
      location_ref: Resource, the parsed location to list deployment resource
        pools.

    Returns:
      Nested attribute containing list of deployment resource pools.
    """
    req = self.messages.AiplatformProjectsLocationsDeploymentResourcePoolsListRequest(
        parent=location_ref.RelativeName()
    )
    return list_pager.YieldFromList(
        self.client.projects_locations_deploymentResourcePools,
        req,
        field='deploymentResourcePools',
        batch_size_attribute='pageSize'
    )

  def QueryDeployedModelsBeta(self, deployment_resource_pool_ref):
    """Queries deployed models sharing a specified deployment resource pool using v1beta1 API.

    Args:
      deployment_resource_pool_ref: str, Deployment resource pool to query.

    Returns:
      GoogleCloudAiplatformV1beta1QueryDeployedModelsResponse message.
    """
    req = self.messages.AiplatformProjectsLocationsDeploymentResourcePoolsQueryDeployedModelsRequest(
        deploymentResourcePool=deployment_resource_pool_ref.RelativeName()
    )
    response = self.client.projects_locations_deploymentResourcePools.QueryDeployedModels(
        req
    )
    return response

View File

@@ -0,0 +1,87 @@
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A library for streaming prediction results from the Vertex AI PredictionService API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import json
from googlecloudsdk.api_lib.util import apis
class PredictionStreamer(object):
  """Streams prediction responses using gRPC."""

  def __init__(self, version):
    # The gapic (gRPC) client is required for the streaming RPCs below.
    self.client = apis.GetGapicClientInstance('aiplatform', version)

  def StreamDirectPredict(
      self,
      endpoint,
      inputs,
      parameters,
  ):
    """Streams prediction results from the Cloud Vertex AI PredictionService API.

    Args:
      endpoint: The name of the endpoint to stream predictions from.
      inputs: The inputs to send to the endpoint.
      parameters: The parameters to send to the endpoint.

    Yields:
      Streamed prediction results.
    """
    types = self.client.types
    request = types.StreamDirectPredictRequest(endpoint=endpoint)
    # Each input and the parameters are converted to Tensor protos via their
    # JSON representations.
    for input_tensor in inputs:
      request.inputs.append(types.Tensor.from_json(json.dumps(input_tensor)))
    request.parameters = types.Tensor.from_json(json.dumps(parameters))
    stream = self.client.prediction.stream_direct_predict(iter([request]))
    for response in stream:
      yield response

  def StreamDirectRawPredict(
      self,
      endpoint,
      method_name,
      input,
  ):
    """Streams prediction results from the Cloud Vertex AI PredictionService API.

    Args:
      endpoint: The name of the endpoint to stream predictions from.
      method_name: The name of the method to call.
      input: The input bytes to send to the endpoint.

    Yields:
      Streamed prediction results.
    """
    # `input` shadows the builtin but is kept for caller compatibility.
    request = self.client.types.StreamDirectRawPredictRequest(
        endpoint=endpoint, method_name=method_name, input=input
    )
    stream = self.client.prediction.stream_direct_raw_predict(iter([request]))
    for response in stream:
      yield response

View File

@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for querying hptuning-jobs in AI platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import messages as messages_util
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.core import yaml
def GetAlgorithmEnum(version=constants.BETA_VERSION):
  """Returns the StudySpec algorithm enum for the given API version."""
  msgs = apis.GetMessagesModule(constants.AI_PLATFORM_API_NAME,
                                constants.AI_PLATFORM_API_VERSION[version])
  # GA and beta surfaces expose the enum under differently-prefixed messages.
  study_spec = (
      msgs.GoogleCloudAiplatformV1StudySpec
      if version == constants.GA_VERSION
      else msgs.GoogleCloudAiplatformV1beta1StudySpec)
  return study_spec.AlgorithmValueValuesEnum
class HpTuningJobsClient(object):
"""Client used for interacting with HyperparameterTuningJob endpoint."""
def __init__(self, version):
client = apis.GetClientInstance(constants.AI_PLATFORM_API_NAME,
constants.AI_PLATFORM_API_VERSION[version])
self._messages = client.MESSAGES_MODULE
self._service = client.projects_locations_hyperparameterTuningJobs
self.version = version
self._message_prefix = constants.AI_PLATFORM_MESSAGE_PREFIX[version]
def _GetMessage(self, message_name):
"""Returns the API messsages class by name."""
return getattr(
self._messages,
'{prefix}{name}'.format(prefix=self._message_prefix,
name=message_name), None)
def HyperparameterTuningJobMessage(self):
"""Returns the HyperparameterTuningJob resource message."""
return self._GetMessage('HyperparameterTuningJob')
def AlgorithmEnum(self):
"""Returns enum message representing Algorithm."""
return self._GetMessage('StudySpec').AlgorithmValueValuesEnum
  def Create(
      self,
      config_path,
      display_name,
      parent=None,
      max_trial_count=None,
      parallel_trial_count=None,
      algorithm=None,
      kms_key_name=None,
      network=None,
      service_account=None,
      enable_web_access=False,
      enable_dashboard_access=False,
      labels=None):
    """Creates a hyperparameter tuning job with given parameters.

    Values set via command-line flags override the corresponding values loaded
    from the YAML/JSON config file; trial counts default to 1 when set in
    neither place.

    Args:
      config_path: str, the file path of the hyperparameter tuning job
        configuration.
      display_name: str, the display name of the created hyperparameter tuning
        job.
      parent: str, parent of the created hyperparameter tuning job. e.g.
        /projects/xxx/locations/xxx/
      max_trial_count: int, the desired total number of Trials. The default
        value is 1.
      parallel_trial_count: int, the desired number of Trials to run in
        parallel. The default value is 1.
      algorithm: AlgorithmValueValuesEnum, the search algorithm specified for
        the Study.
      kms_key_name: str, A customer-managed encryption key to use for the
        hyperparameter tuning job.
      network: str, user network to which the job should be peered with
        (overrides yaml file)
      service_account: str, A service account (email address string) to use for
        the job.
      enable_web_access: bool, Whether to enable the interactive shell for the
        job.
      enable_dashboard_access: bool, Whether to enable the dashboard defined for
        the job.
      labels: LabelsValues, map-like user-defined metadata to organize the
        hp-tuning jobs.

    Returns:
      Created hyperparameter tuning job.
    """
    # NOTE(review): HyperparameterTuningJobMessage() returns the message
    # *class*, so job_spec is a class object here until (and unless) the
    # config file replaces it with an instance below — confirm intended.
    job_spec = self.HyperparameterTuningJobMessage()
    if config_path:
      data = yaml.load_path(config_path)
      if data:
        job_spec = messages_util.DictToMessageWithErrorCheck(
            data, self.HyperparameterTuningJobMessage())
    # Command-line trial counts win over the config file; both default to 1.
    if not job_spec.maxTrialCount and not max_trial_count:
      job_spec.maxTrialCount = 1
    elif max_trial_count:
      job_spec.maxTrialCount = max_trial_count
    if not job_spec.parallelTrialCount and not parallel_trial_count:
      job_spec.parallelTrialCount = 1
    elif parallel_trial_count:
      job_spec.parallelTrialCount = parallel_trial_count
    # assumes trialJobSpec is already populated (e.g. from the config file)
    # when any of the following overrides are passed — TODO confirm.
    if network:
      job_spec.trialJobSpec.network = network
    if service_account:
      job_spec.trialJobSpec.serviceAccount = service_account
    if enable_web_access:
      job_spec.trialJobSpec.enableWebAccess = enable_web_access
    if enable_dashboard_access:
      job_spec.trialJobSpec.enableDashboardAccess = enable_dashboard_access
    if display_name:
      job_spec.displayName = display_name
    # The algorithm override only applies when a studySpec was configured.
    if algorithm and job_spec.studySpec:
      job_spec.studySpec.algorithm = algorithm
    if kms_key_name is not None:
      job_spec.encryptionSpec = self._GetMessage('EncryptionSpec')(
          kmsKeyName=kms_key_name)
    if labels:
      job_spec.labels = labels
    # GA and beta requests differ only in the versioned job field name.
    if self.version == constants.GA_VERSION:
      request = self._messages.AiplatformProjectsLocationsHyperparameterTuningJobsCreateRequest(
          parent=parent,
          googleCloudAiplatformV1HyperparameterTuningJob=job_spec)
    else:
      request = self._messages.AiplatformProjectsLocationsHyperparameterTuningJobsCreateRequest(
          parent=parent,
          googleCloudAiplatformV1beta1HyperparameterTuningJob=job_spec)
    return self._service.Create(request)
def Get(self, name=None):
request = self._messages.AiplatformProjectsLocationsHyperparameterTuningJobsGetRequest(
name=name)
return self._service.Get(request)
def Cancel(self, name=None):
request = self._messages.AiplatformProjectsLocationsHyperparameterTuningJobsCancelRequest(
name=name)
return self._service.Cancel(request)
def List(self, limit=None, region=None):
return list_pager.YieldFromList(
self._service,
self._messages
.AiplatformProjectsLocationsHyperparameterTuningJobsListRequest(
parent=region),
field='hyperparameterTuningJobs',
batch_size_attribute='pageSize',
limit=limit)
def CheckJobComplete(self, name):
"""Returns a function to decide if log fetcher should continue polling.
Args:
name: String id of job.
Returns:
A one-argument function decides if log fetcher should continue.
"""
request = self._messages.AiplatformProjectsLocationsHyperparameterTuningJobsGetRequest(
name=name)
response = self._service.Get(request)
def ShouldContinue(periods_without_logs):
if periods_without_logs <= 1:
return True
return response.endTime is None
return ShouldContinue

View File

@@ -0,0 +1,518 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with AI Platform index endpoints API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
def _ParseIndex(index_id, location_id):
  """Parses an index ID into an index resource reference."""
  # The project defaults to the active gcloud project at resolution time.
  parse_params = {
      'locationsId': location_id,
      'projectsId': properties.VALUES.core.project.GetOrFail,
  }
  return resources.REGISTRY.Parse(
      index_id,
      params=parse_params,
      collection='aiplatform.projects.locations.indexes')
class IndexEndpointsClient(object):
"""High-level client for the AI Platform index endpoints surface."""
def __init__(self, client=None, messages=None, version=constants.GA_VERSION):
self.client = client or apis.GetClientInstance(
constants.AI_PLATFORM_API_NAME,
constants.AI_PLATFORM_API_VERSION[version])
self.messages = messages or self.client.MESSAGES_MODULE
self._service = self.client.projects_locations_indexEndpoints
def CreateBeta(self, location_ref, args):
"""Create a new index endpoint."""
labels = labels_util.ParseCreateArgs(
args,
self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint.LabelsValue)
encryption_spec = None
if args.encryption_kms_key_name is not None:
encryption_spec = (
self.messages.GoogleCloudAiplatformV1beta1EncryptionSpec(
kmsKeyName=args.encryption_kms_key_name))
private_service_connect_config = None
if args.enable_private_service_connect:
private_service_connect_config = (
self.messages.GoogleCloudAiplatformV1beta1PrivateServiceConnectConfig(
enablePrivateServiceConnect=args.enable_private_service_connect,
projectAllowlist=(args.project_allowlist
if args.project_allowlist else [])
)
)
req = self.messages.AiplatformProjectsLocationsIndexEndpointsCreateRequest(
parent=location_ref.RelativeName(),
googleCloudAiplatformV1beta1IndexEndpoint=self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint(
displayName=args.display_name,
description=args.description,
publicEndpointEnabled=args.public_endpoint_enabled,
labels=labels,
encryptionSpec=encryption_spec,
privateServiceConnectConfig=private_service_connect_config,
),
)
elif args.network is not None:
req = self.messages.AiplatformProjectsLocationsIndexEndpointsCreateRequest(
parent=location_ref.RelativeName(),
googleCloudAiplatformV1beta1IndexEndpoint=self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint(
displayName=args.display_name,
description=args.description,
network=args.network,
labels=labels,
encryptionSpec=encryption_spec,
),
)
else:
req = self.messages.AiplatformProjectsLocationsIndexEndpointsCreateRequest(
parent=location_ref.RelativeName(),
googleCloudAiplatformV1beta1IndexEndpoint=self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint(
displayName=args.display_name,
description=args.description,
publicEndpointEnabled=True,
labels=labels,
encryptionSpec=encryption_spec,
privateServiceConnectConfig=private_service_connect_config,
),
)
return self._service.Create(req)
def Create(self, location_ref, args):
"""Create a new v1 index endpoint."""
labels = labels_util.ParseCreateArgs(
args, self.messages.GoogleCloudAiplatformV1IndexEndpoint.LabelsValue)
encryption_spec = None
if args.encryption_kms_key_name is not None:
encryption_spec = (
self.messages.GoogleCloudAiplatformV1EncryptionSpec(
kmsKeyName=args.encryption_kms_key_name))
private_service_connect_config = None
if args.enable_private_service_connect:
private_service_connect_config = (
self.messages.GoogleCloudAiplatformV1PrivateServiceConnectConfig(
enablePrivateServiceConnect=args.enable_private_service_connect,
projectAllowlist=(args.project_allowlist
if args.project_allowlist else []),
)
)
req = self.messages.AiplatformProjectsLocationsIndexEndpointsCreateRequest(
parent=location_ref.RelativeName(),
googleCloudAiplatformV1IndexEndpoint=self.messages.GoogleCloudAiplatformV1IndexEndpoint(
displayName=args.display_name,
description=args.description,
publicEndpointEnabled=args.public_endpoint_enabled,
labels=labels,
encryptionSpec=encryption_spec,
privateServiceConnectConfig=private_service_connect_config,
),
)
elif args.network is not None:
req = self.messages.AiplatformProjectsLocationsIndexEndpointsCreateRequest(
parent=location_ref.RelativeName(),
googleCloudAiplatformV1IndexEndpoint=self.messages.GoogleCloudAiplatformV1IndexEndpoint(
displayName=args.display_name,
description=args.description,
network=args.network,
labels=labels,
encryptionSpec=encryption_spec,
),
)
else:
req = self.messages.AiplatformProjectsLocationsIndexEndpointsCreateRequest(
parent=location_ref.RelativeName(),
googleCloudAiplatformV1IndexEndpoint=self.messages.GoogleCloudAiplatformV1IndexEndpoint(
displayName=args.display_name,
description=args.description,
publicEndpointEnabled=True,
labels=labels,
encryptionSpec=encryption_spec,
privateServiceConnectConfig=private_service_connect_config,
),
)
return self._service.Create(req)
  def PatchBeta(self, index_endpoint_ref, args):
    """Updates a v1beta1 index endpoint.

    Only display name, description and labels are mutable; the update mask is
    built from exactly the fields that were supplied.

    Args:
      index_endpoint_ref: Resource reference of the index endpoint.
      args: Parsed command-line arguments with the fields to update.

    Returns:
      The long-running operation returned by the Patch call.

    Raises:
      errors.NoFieldsSpecifiedError: If no updatable field was supplied.
    """
    index_endpoint = self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint()
    update_mask = []
    if args.display_name is not None:
      index_endpoint.displayName = args.display_name
      update_mask.append('display_name')
    if args.description is not None:
      index_endpoint.description = args.description
      update_mask.append('description')
    def GetLabels():
      # Current labels are fetched lazily, only if a label update needs them.
      return self.Get(index_endpoint_ref).labels
    labels_update = labels_util.ProcessUpdateArgsLazy(
        args,
        self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint.LabelsValue,
        GetLabels)
    if labels_update.needs_update:
      index_endpoint.labels = labels_update.labels
      update_mask.append('labels')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    request = self.messages.AiplatformProjectsLocationsIndexEndpointsPatchRequest(
        name=index_endpoint_ref.RelativeName(),
        googleCloudAiplatformV1beta1IndexEndpoint=index_endpoint,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)
  def Patch(self, index_endpoint_ref, args):
    """Updates a v1 index endpoint.

    Only display name, description and labels are mutable; the update mask is
    built from exactly the fields that were supplied.

    Args:
      index_endpoint_ref: Resource reference of the index endpoint.
      args: Parsed command-line arguments with the fields to update.

    Returns:
      The long-running operation returned by the Patch call.

    Raises:
      errors.NoFieldsSpecifiedError: If no updatable field was supplied.
    """
    index_endpoint = self.messages.GoogleCloudAiplatformV1IndexEndpoint()
    update_mask = []
    if args.display_name is not None:
      index_endpoint.displayName = args.display_name
      update_mask.append('display_name')
    if args.description is not None:
      index_endpoint.description = args.description
      update_mask.append('description')
    def GetLabels():
      # Current labels are fetched lazily, only if a label update needs them.
      return self.Get(index_endpoint_ref).labels
    labels_update = labels_util.ProcessUpdateArgsLazy(
        args, self.messages.GoogleCloudAiplatformV1IndexEndpoint.LabelsValue,
        GetLabels)
    if labels_update.needs_update:
      index_endpoint.labels = labels_update.labels
      update_mask.append('labels')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    request = self.messages.AiplatformProjectsLocationsIndexEndpointsPatchRequest(
        name=index_endpoint_ref.RelativeName(),
        googleCloudAiplatformV1IndexEndpoint=index_endpoint,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)
  def DeployIndexBeta(self, index_endpoint_ref, args):
    """Deploys an index to a v1beta1 index endpoint.

    Args:
      index_endpoint_ref: Resource reference of the target index endpoint.
      args: Parsed command-line arguments describing the deployment.

    Returns:
      The long-running operation returned by the DeployIndex call.
    """
    index_ref = _ParseIndex(args.index, args.region)
    deployed_index = self.messages.GoogleCloudAiplatformV1beta1DeployedIndex(
        displayName=args.display_name,
        id=args.deployed_index_id,
        index=index_ref.RelativeName(),
    )
    if args.reserved_ip_ranges is not None:
      deployed_index.reservedIpRanges.extend(args.reserved_ip_ranges)
    if args.deployment_group is not None:
      deployed_index.deploymentGroup = args.deployment_group
    if args.deployment_tier:
      # The flag value is lower-case; the enum expects upper-case names.
      deployed_index.deploymentTier = self.messages.GoogleCloudAiplatformV1beta1DeployedIndex.DeploymentTierValueValuesEnum(
          args.deployment_tier.upper())
    if args.enable_access_logging is not None:
      deployed_index.enableAccessLogging = args.enable_access_logging
    # JWT auth is only configured when both audiences and issuers are given.
    if args.audiences is not None and args.allowed_issuers is not None:
      auth_provider = self.messages.GoogleCloudAiplatformV1beta1DeployedIndexAuthConfigAuthProvider()
      auth_provider.audiences.extend(args.audiences)
      auth_provider.allowedIssuers.extend(args.allowed_issuers)
      deployed_index.deployedIndexAuthConfig = (
          self.messages.GoogleCloudAiplatformV1beta1DeployedIndexAuthConfig(
              authProvider=auth_provider))
    # A machine type selects dedicated resources; otherwise autoscaling via
    # automatic resources is used.
    if args.machine_type is not None:
      dedicated_resources = (
          self.messages.GoogleCloudAiplatformV1beta1DedicatedResources()
      )
      dedicated_resources.machineSpec = (
          self.messages.GoogleCloudAiplatformV1beta1MachineSpec(
              machineType=args.machine_type
          )
      )
      if args.min_replica_count is not None:
        dedicated_resources.minReplicaCount = args.min_replica_count
      if args.max_replica_count is not None:
        dedicated_resources.maxReplicaCount = args.max_replica_count
      deployed_index.dedicatedResources = dedicated_resources
    else:
      automatic_resources = (
          self.messages.GoogleCloudAiplatformV1beta1AutomaticResources()
      )
      if args.min_replica_count is not None:
        automatic_resources.minReplicaCount = args.min_replica_count
      if args.max_replica_count is not None:
        automatic_resources.maxReplicaCount = args.max_replica_count
      deployed_index.automaticResources = automatic_resources
    deploy_index_req = self.messages.GoogleCloudAiplatformV1beta1DeployIndexRequest(
        deployedIndex=deployed_index)
    request = self.messages.AiplatformProjectsLocationsIndexEndpointsDeployIndexRequest(
        indexEndpoint=index_endpoint_ref.RelativeName(),
        googleCloudAiplatformV1beta1DeployIndexRequest=deploy_index_req)
    return self._service.DeployIndex(request)
  def DeployIndex(self, index_endpoint_ref, args):
    """Deploys a v1 index to an index endpoint.

    Args:
      index_endpoint_ref: Resource reference of the target index endpoint.
      args: Parsed command-line arguments describing the deployment.

    Returns:
      The long-running operation returned by the DeployIndex call.
    """
    index_ref = _ParseIndex(args.index, args.region)
    deployed_index = self.messages.GoogleCloudAiplatformV1DeployedIndex(
        displayName=args.display_name,
        id=args.deployed_index_id,
        index=index_ref.RelativeName(),
        enableAccessLogging=args.enable_access_logging
    )
    if args.reserved_ip_ranges is not None:
      deployed_index.reservedIpRanges.extend(args.reserved_ip_ranges)
    if args.deployment_group is not None:
      deployed_index.deploymentGroup = args.deployment_group
    if args.deployment_tier:
      # The flag value is lower-case; the enum expects upper-case names.
      deployed_index.deploymentTier = self.messages.GoogleCloudAiplatformV1DeployedIndex.DeploymentTierValueValuesEnum(
          args.deployment_tier.upper())
    # JWT Authentication config
    if args.audiences is not None and args.allowed_issuers is not None:
      auth_provider = self.messages.GoogleCloudAiplatformV1DeployedIndexAuthConfigAuthProvider()
      auth_provider.audiences.extend(args.audiences)
      auth_provider.allowedIssuers.extend(args.allowed_issuers)
      deployed_index.deployedIndexAuthConfig = (
          self.messages.GoogleCloudAiplatformV1DeployedIndexAuthConfig(
              authProvider=auth_provider))
    # PSC automation configs
    if args.psc_automation_configs is not None:
      deployed_index.pscAutomationConfigs = []
      # assumes each entry is a dict with 'project-id' and 'network' keys,
      # as produced by the flag's arg spec — TODO confirm.
      for psc_automation_config in args.psc_automation_configs:
        deployed_index.pscAutomationConfigs.append(
            self.messages.GoogleCloudAiplatformV1PSCAutomationConfig(
                projectId=psc_automation_config['project-id'],
                network=psc_automation_config['network'],
            )
        )
    # A machine type selects dedicated resources; otherwise autoscaling via
    # automatic resources is used.
    if args.machine_type is not None:
      dedicated_resources = (
          self.messages.GoogleCloudAiplatformV1DedicatedResources()
      )
      dedicated_resources.machineSpec = (
          self.messages.GoogleCloudAiplatformV1MachineSpec(
              machineType=args.machine_type
          )
      )
      if args.min_replica_count is not None:
        dedicated_resources.minReplicaCount = args.min_replica_count
      if args.max_replica_count is not None:
        dedicated_resources.maxReplicaCount = args.max_replica_count
      deployed_index.dedicatedResources = dedicated_resources
    else:
      automatic_resources = (
          self.messages.GoogleCloudAiplatformV1AutomaticResources()
      )
      if args.min_replica_count is not None:
        automatic_resources.minReplicaCount = args.min_replica_count
      if args.max_replica_count is not None:
        automatic_resources.maxReplicaCount = args.max_replica_count
      deployed_index.automaticResources = automatic_resources
    deploy_index_req = self.messages.GoogleCloudAiplatformV1DeployIndexRequest(
        deployedIndex=deployed_index)
    request = self.messages.AiplatformProjectsLocationsIndexEndpointsDeployIndexRequest(
        indexEndpoint=index_endpoint_ref.RelativeName(),
        googleCloudAiplatformV1DeployIndexRequest=deploy_index_req)
    return self._service.DeployIndex(request)
def UndeployIndexBeta(self, index_endpoint_ref, args):
"""Undeploy an index to an index endpoint."""
undeploy_index_req = self.messages.GoogleCloudAiplatformV1beta1UndeployIndexRequest(
deployedIndexId=args.deployed_index_id)
request = self.messages.AiplatformProjectsLocationsIndexEndpointsUndeployIndexRequest(
indexEndpoint=index_endpoint_ref.RelativeName(),
googleCloudAiplatformV1beta1UndeployIndexRequest=undeploy_index_req)
return self._service.UndeployIndex(request)
def UndeployIndex(self, index_endpoint_ref, args):
"""Undeploy an v1 index to an index endpoint."""
undeploy_index_req = self.messages.GoogleCloudAiplatformV1UndeployIndexRequest(
deployedIndexId=args.deployed_index_id)
request = self.messages.AiplatformProjectsLocationsIndexEndpointsUndeployIndexRequest(
indexEndpoint=index_endpoint_ref.RelativeName(),
googleCloudAiplatformV1UndeployIndexRequest=undeploy_index_req)
return self._service.UndeployIndex(request)
  def MutateDeployedIndexBeta(self, index_endpoint_ref, args):
    """Mutates an existing deployed index on a v1beta1 index endpoint.

    Args:
      index_endpoint_ref: Resource reference of the index endpoint.
      args: Parsed command-line arguments with the new deployment settings.

    Returns:
      The long-running operation returned by the MutateDeployedIndex call.
    """
    deployed_index = self.messages.GoogleCloudAiplatformV1beta1DeployedIndex(
        id=args.deployed_index_id,
        enableAccessLogging=args.enable_access_logging,
    )
    # A machine type selects dedicated resources; otherwise automatic.
    if args.machine_type is not None:
      deployed_index.dedicatedResources = self._GetDedicatedResourcesBeta(args)
    else:
      deployed_index.automaticResources = self._GetAutomaticResourcesBeta(args)
    if args.reserved_ip_ranges is not None:
      deployed_index.reservedIpRanges.extend(args.reserved_ip_ranges)
    if args.deployment_group is not None:
      deployed_index.deploymentGroup = args.deployment_group
    # JWT auth is only configured when both audiences and issuers are given.
    if args.audiences is not None and args.allowed_issuers is not None:
      auth_provider = self.messages.GoogleCloudAiplatformV1beta1DeployedIndexAuthConfigAuthProvider()
      auth_provider.audiences.extend(args.audiences)
      auth_provider.allowedIssuers.extend(args.allowed_issuers)
      deployed_index.deployedIndexAuthConfig = (
          self.messages.GoogleCloudAiplatformV1beta1DeployedIndexAuthConfig(
              authProvider=auth_provider))
    request = self.messages.AiplatformProjectsLocationsIndexEndpointsMutateDeployedIndexRequest(
        indexEndpoint=index_endpoint_ref.RelativeName(),
        googleCloudAiplatformV1beta1DeployedIndex=deployed_index)
    return self._service.MutateDeployedIndex(request)
  def MutateDeployedIndex(self, index_endpoint_ref, args):
    """Mutates an existing deployed index on a v1 index endpoint.

    Args:
      index_endpoint_ref: Resource reference of the index endpoint.
      args: Parsed command-line arguments with the new deployment settings.

    Returns:
      The long-running operation returned by the MutateDeployedIndex call.
    """
    deployed_index = self.messages.GoogleCloudAiplatformV1DeployedIndex(
        id=args.deployed_index_id,
        enableAccessLogging=args.enable_access_logging,
    )
    # A machine type selects dedicated resources; otherwise automatic.
    if args.machine_type is not None:
      deployed_index.dedicatedResources = self._GetDedicatedResources(args)
    else:
      deployed_index.automaticResources = self._GetAutomaticResources(args)
    if args.reserved_ip_ranges is not None:
      deployed_index.reservedIpRanges.extend(args.reserved_ip_ranges)
    if args.deployment_group is not None:
      deployed_index.deploymentGroup = args.deployment_group
    # JWT auth is only configured when both audiences and issuers are given.
    if args.audiences is not None and args.allowed_issuers is not None:
      auth_provider = self.messages.GoogleCloudAiplatformV1DeployedIndexAuthConfigAuthProvider()
      auth_provider.audiences.extend(args.audiences)
      auth_provider.allowedIssuers.extend(args.allowed_issuers)
      deployed_index.deployedIndexAuthConfig = (
          self.messages.GoogleCloudAiplatformV1DeployedIndexAuthConfig(
              authProvider=auth_provider))
    request = self.messages.AiplatformProjectsLocationsIndexEndpointsMutateDeployedIndexRequest(
        indexEndpoint=index_endpoint_ref.RelativeName(),
        googleCloudAiplatformV1DeployedIndex=deployed_index)
    return self._service.MutateDeployedIndex(request)
def Get(self, index_endpoint_ref):
request = self.messages.AiplatformProjectsLocationsIndexEndpointsGetRequest(
name=index_endpoint_ref.RelativeName())
return self._service.Get(request)
def List(self, limit=None, region_ref=None):
return list_pager.YieldFromList(
self._service,
self.messages.AiplatformProjectsLocationsIndexEndpointsListRequest(
parent=region_ref.RelativeName()),
field='indexEndpoints',
batch_size_attribute='pageSize',
limit=limit)
def Delete(self, index_endpoint_ref):
request = self.messages.AiplatformProjectsLocationsIndexEndpointsDeleteRequest(
name=index_endpoint_ref.RelativeName())
return self._service.Delete(request)
def _GetDedicatedResourcesBeta(self, args):
"""Construct dedicated resources for beta API."""
dedicated_resources = (
self.messages.GoogleCloudAiplatformV1beta1DedicatedResources()
)
dedicated_resources.machineSpec = (
self.messages.GoogleCloudAiplatformV1beta1MachineSpec(
machineType=args.machine_type
)
)
if args.min_replica_count is not None:
dedicated_resources.minReplicaCount = args.min_replica_count
if args.max_replica_count is not None:
dedicated_resources.maxReplicaCount = args.max_replica_count
return dedicated_resources
def _GetAutomaticResourcesBeta(self, args):
"""Construct automatic resources for beta API."""
automatic_resources = (
self.messages.GoogleCloudAiplatformV1beta1AutomaticResources()
)
if args.min_replica_count is not None:
automatic_resources.minReplicaCount = args.min_replica_count
if args.max_replica_count is not None:
automatic_resources.maxReplicaCount = args.max_replica_count
return automatic_resources
def _GetDedicatedResources(self, args):
"""Construct dedicated resources for GA API."""
dedicated_resources = (
self.messages.GoogleCloudAiplatformV1DedicatedResources()
)
dedicated_resources.machineSpec = (
self.messages.GoogleCloudAiplatformV1MachineSpec(
machineType=args.machine_type
)
)
if args.min_replica_count is not None:
dedicated_resources.minReplicaCount = args.min_replica_count
if args.max_replica_count is not None:
dedicated_resources.maxReplicaCount = args.max_replica_count
return dedicated_resources
def _GetAutomaticResources(self, args):
"""Construct automatic resources for GA API."""
automatic_resources = (
self.messages.GoogleCloudAiplatformV1AutomaticResources()
)
if args.min_replica_count is not None:
automatic_resources.minReplicaCount = args.min_replica_count
if args.max_replica_count is not None:
automatic_resources.maxReplicaCount = args.max_replica_count
return automatic_resources

View File

@@ -0,0 +1,313 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with AI Platform indexes API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import extra_types
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import messages as messages_util
from googlecloudsdk.calliope import exceptions as gcloud_exceptions
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import yaml
class IndexesClient(object):
"""High-level client for the AI Platform indexes surface."""
def __init__(self, client=None, messages=None, version=None):
self.client = client or apis.GetClientInstance(
constants.AI_PLATFORM_API_NAME,
constants.AI_PLATFORM_API_VERSION[version])
self.messages = messages or self.client.MESSAGES_MODULE
self._service = self.client.projects_locations_indexes
def _ReadIndexMetadata(self, metadata_file):
"""Parse json metadata file."""
if not metadata_file:
raise gcloud_exceptions.BadArgumentException(
'--metadata-file', 'Index metadata file must be specified.')
index_metadata = None
# Yaml is a superset of json, so parse json file as yaml.
data = yaml.load_path(metadata_file)
if data:
index_metadata = messages_util.DictToMessageWithErrorCheck(
data, extra_types.JsonValue)
return index_metadata
def Get(self, index_ref):
request = self.messages.AiplatformProjectsLocationsIndexesGetRequest(
name=index_ref.RelativeName())
return self._service.Get(request)
def List(self, limit=None, region_ref=None):
return list_pager.YieldFromList(
self._service,
self.messages.AiplatformProjectsLocationsIndexesListRequest(
parent=region_ref.RelativeName()),
field='indexes',
batch_size_attribute='pageSize',
limit=limit)
  def CreateBeta(self, location_ref, args):
    """Creates a new v1beta1 index.

    Args:
      location_ref: Resource reference of the parent location.
      args: Parsed command-line arguments.

    Returns:
      The long-running operation returned by the Create call.

    Raises:
      gcloud_exceptions.BadArgumentException: If --index-update-method has an
        unrecognized value.
    """
    labels = labels_util.ParseCreateArgs(
        args, self.messages.GoogleCloudAiplatformV1beta1Index.LabelsValue)
    index_update_method = None
    # Map the CLI flag values onto the versioned enum.
    if args.index_update_method:
      if args.index_update_method == 'stream-update':
        index_update_method = (
            self.messages.GoogleCloudAiplatformV1beta1Index.
            IndexUpdateMethodValueValuesEnum.STREAM_UPDATE)
      elif args.index_update_method == 'batch-update':
        index_update_method = (
            self.messages.GoogleCloudAiplatformV1beta1Index.
            IndexUpdateMethodValueValuesEnum.BATCH_UPDATE)
      else:
        raise gcloud_exceptions.BadArgumentException(
            '--index-update-method',
            'Invalid index update method: {}'.format(args.index_update_method),
        )
    encryption_spec = None
    if args.encryption_kms_key_name is not None:
      encryption_spec = (
          self.messages.GoogleCloudAiplatformV1beta1EncryptionSpec(
              kmsKeyName=args.encryption_kms_key_name))
    req = self.messages.AiplatformProjectsLocationsIndexesCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1beta1Index=self.messages
        .GoogleCloudAiplatformV1beta1Index(
            displayName=args.display_name,
            description=args.description,
            metadata=self._ReadIndexMetadata(args.metadata_file),
            labels=labels,
            indexUpdateMethod=index_update_method,
            encryptionSpec=encryption_spec
        ))
    return self._service.Create(req)
  def Create(self, location_ref, args):
    """Creates a new v1 index.

    Args:
      location_ref: Resource reference of the parent location.
      args: Parsed command-line arguments.

    Returns:
      The long-running operation returned by the Create call.

    Raises:
      gcloud_exceptions.BadArgumentException: If --index-update-method has an
        unrecognized value.
    """
    labels = labels_util.ParseCreateArgs(
        args, self.messages.GoogleCloudAiplatformV1Index.LabelsValue)
    index_update_method = None
    # Map the CLI flag values onto the versioned enum.
    if args.index_update_method:
      if args.index_update_method == 'stream-update':
        index_update_method = (
            self.messages.GoogleCloudAiplatformV1Index
            .IndexUpdateMethodValueValuesEnum.STREAM_UPDATE)
      elif args.index_update_method == 'batch-update':
        index_update_method = (
            self.messages.GoogleCloudAiplatformV1Index.IndexUpdateMethodValueValuesEnum.BATCH_UPDATE
        )
      else:
        raise gcloud_exceptions.BadArgumentException(
            '--index-update-method',
            'Invalid index update method: {}'.format(args.index_update_method),
        )
    encryption_spec = None
    if args.encryption_kms_key_name is not None:
      encryption_spec = (
          self.messages.GoogleCloudAiplatformV1EncryptionSpec(
              kmsKeyName=args.encryption_kms_key_name))
    req = self.messages.AiplatformProjectsLocationsIndexesCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1Index=self.messages.GoogleCloudAiplatformV1Index(
            displayName=args.display_name,
            description=args.description,
            metadata=self._ReadIndexMetadata(args.metadata_file),
            labels=labels,
            indexUpdateMethod=index_update_method,
            encryptionSpec=encryption_spec
        ))
    return self._service.Create(req)
  def PatchBeta(self, index_ref, args):
    """Updates a v1beta1 index.

    A metadata-file update replaces the entire metadata and takes precedence
    over display-name/description/label updates.

    Args:
      index_ref: Resource reference of the index to update.
      args: Parsed command-line arguments with the fields to update.

    Returns:
      The long-running operation returned by the Patch call.

    Raises:
      errors.NoFieldsSpecifiedError: If no updatable field was supplied.
    """
    index = self.messages.GoogleCloudAiplatformV1beta1Index()
    update_mask = []
    if args.metadata_file is not None:
      index.metadata = self._ReadIndexMetadata(args.metadata_file)
      update_mask.append('metadata')
    else:
      if args.display_name is not None:
        index.displayName = args.display_name
        update_mask.append('display_name')
      if args.description is not None:
        index.description = args.description
        update_mask.append('description')
      def GetLabels():
        # Current labels are fetched lazily, only if an update needs them.
        return self.Get(index_ref).labels
      labels_update = labels_util.ProcessUpdateArgsLazy(
          args, self.messages.GoogleCloudAiplatformV1beta1Index.LabelsValue,
          GetLabels)
      if labels_update.needs_update:
        index.labels = labels_update.labels
        update_mask.append('labels')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    request = self.messages.AiplatformProjectsLocationsIndexesPatchRequest(
        name=index_ref.RelativeName(),
        googleCloudAiplatformV1beta1Index=index,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)
  def Patch(self, index_ref, args):
    """Updates a v1 index.

    A metadata-file update replaces the entire metadata and takes precedence
    over display-name/description/label updates.

    Args:
      index_ref: Resource reference of the index to update.
      args: Parsed command-line arguments with the fields to update.

    Returns:
      The long-running operation returned by the Patch call.

    Raises:
      errors.NoFieldsSpecifiedError: If no updatable field was supplied.
    """
    index = self.messages.GoogleCloudAiplatformV1Index()
    update_mask = []
    if args.metadata_file is not None:
      index.metadata = self._ReadIndexMetadata(args.metadata_file)
      update_mask.append('metadata')
    else:
      if args.display_name is not None:
        index.displayName = args.display_name
        update_mask.append('display_name')
      if args.description is not None:
        index.description = args.description
        update_mask.append('description')
      def GetLabels():
        # Current labels are fetched lazily, only if an update needs them.
        return self.Get(index_ref).labels
      labels_update = labels_util.ProcessUpdateArgsLazy(
          args, self.messages.GoogleCloudAiplatformV1Index.LabelsValue,
          GetLabels)
      if labels_update.needs_update:
        index.labels = labels_update.labels
        update_mask.append('labels')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    request = self.messages.AiplatformProjectsLocationsIndexesPatchRequest(
        name=index_ref.RelativeName(),
        googleCloudAiplatformV1Index=index,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)
def Delete(self, index_ref):
request = self.messages.AiplatformProjectsLocationsIndexesDeleteRequest(
name=index_ref.RelativeName())
return self._service.Delete(request)
def RemoveDatapointsBeta(self, index_ref, args):
"""Remove data points from a v1beta1 index."""
if args.datapoint_ids and args.datapoints_from_file:
raise errors.ArgumentError(
'datapoint_ids and datapoints_from_file can not be set'
' at the same time.'
)
if args.datapoint_ids:
req = self.messages.AiplatformProjectsLocationsIndexesRemoveDatapointsRequest(
index=index_ref.RelativeName(),
googleCloudAiplatformV1beta1RemoveDatapointsRequest=self.messages
.GoogleCloudAiplatformV1beta1RemoveDatapointsRequest(
datapointIds=args.datapoint_ids))
if args.datapoints_from_file:
data = yaml.load_path(args.datapoints_from_file)
req = self.messages.AiplatformProjectsLocationsIndexesRemoveDatapointsRequest(
index=index_ref.RelativeName(),
googleCloudAiplatformV1beta1RemoveDatapointsRequest=self.messages
.GoogleCloudAiplatformV1beta1RemoveDatapointsRequest(
datapointIds=data))
return self._service.RemoveDatapoints(req)
def RemoveDatapoints(self, index_ref, args):
"""Remove data points from a v1 index."""
if args.datapoint_ids and args.datapoints_from_file:
raise errors.ArgumentError(
'`--datapoint_ids` and `--datapoints_from_file` can not be set at the'
' same time.'
)
if args.datapoint_ids:
req = self.messages.AiplatformProjectsLocationsIndexesRemoveDatapointsRequest(
index=index_ref.RelativeName(),
googleCloudAiplatformV1RemoveDatapointsRequest=self.messages
.GoogleCloudAiplatformV1RemoveDatapointsRequest(
datapointIds=args.datapoint_ids))
if args.datapoints_from_file:
data = yaml.load_path(args.datapoints_from_file)
req = self.messages.AiplatformProjectsLocationsIndexesRemoveDatapointsRequest(
index=index_ref.RelativeName(),
googleCloudAiplatformV1RemoveDatapointsRequest=self.messages
.GoogleCloudAiplatformV1RemoveDatapointsRequest(
datapointIds=data))
return self._service.RemoveDatapoints(req)
def UpsertDatapointsBeta(self, index_ref, args):
"""Upsert data points from a v1beta1 index."""
datapoints = []
if args.datapoints_from_file:
data = yaml.load_path(args.datapoints_from_file)
for datapoint_json in data:
datapoint = messages_util.DictToMessageWithErrorCheck(
datapoint_json,
self.messages.GoogleCloudAiplatformV1beta1IndexDatapoint)
datapoints.append(datapoint)
update_mask = None
if args.update_mask:
update_mask = ','.join(args.update_mask)
req = self.messages.AiplatformProjectsLocationsIndexesUpsertDatapointsRequest(
index=index_ref.RelativeName(),
googleCloudAiplatformV1beta1UpsertDatapointsRequest=self.messages
.GoogleCloudAiplatformV1beta1UpsertDatapointsRequest(
datapoints=datapoints,
updateMask=update_mask))
return self._service.UpsertDatapoints(req)
def UpsertDatapoints(self, index_ref, args):
"""Upsert data points from a v1 index."""
datapoints = []
if args.datapoints_from_file:
data = yaml.load_path(args.datapoints_from_file)
for datapoint_json in data:
datapoint = messages_util.DictToMessageWithErrorCheck(
datapoint_json,
self.messages.GoogleCloudAiplatformV1IndexDatapoint)
datapoints.append(datapoint)
update_mask = None
if args.update_mask:
update_mask = ','.join(args.update_mask)
req = self.messages.AiplatformProjectsLocationsIndexesUpsertDatapointsRequest(
index=index_ref.RelativeName(),
googleCloudAiplatformV1UpsertDatapointsRequest=self.messages
.GoogleCloudAiplatformV1UpsertDatapointsRequest(
datapoints=datapoints,
updateMask=update_mask))
return self._service.UpsertDatapoints(req)

View File

@@ -0,0 +1,515 @@
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Vertex AI Model Garden APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import re
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import flags
_HF_WILDCARD_FILTER = 'is_hf_wildcard(true)'
_NATIVE_MODEL_FILTER = 'is_hf_wildcard(false)'
_VERIFIED_DEPLOYMENT_FILTER = (
'labels.VERIFIED_DEPLOYMENT_CONFIG=VERIFIED_DEPLOYMENT_SUCCEED'
)
def IsHuggingFaceModel(model_name: str) -> bool:
  """Returns True iff the name looks like a Hugging Face model ID.

  A Hugging Face model ID has the form ``{organization}/{model}`` with
  exactly one slash and no ``@`` version suffix in the model part.
  """
  return re.match(r'^[^/]+/[^/@]+$', model_name) is not None
def IsCustomWeightsModel(model: str) -> bool:
  """Returns True iff the model points at custom weights in Cloud Storage."""
  return model.startswith('gs://')
def DeployCustomWeightsModel(
    messages,
    projects_locations_service,
    model,
    machine_type,
    accelerator_type,
    accelerator_count,
    project,
    location,
):
  """Deploys a model from custom weights stored in Cloud Storage.

  Args:
    messages: The Vertex AI v1beta1 messages module.
    projects_locations_service: The projects.locations service client.
    model: Cloud Storage URI (gs://...) of the custom model weights.
    machine_type: Machine type for the deployment; when falsy, no explicit
      deploy config is sent and server-side defaults apply.
    accelerator_type: Accelerator type to attach, if any.
    accelerator_count: Number of accelerators to attach.
    project: Project to deploy into.
    location: Region to deploy into.

  Returns:
    The deploy long-running operation.
  """
  custom_model = messages.GoogleCloudAiplatformV1beta1DeployRequestCustomModel(
      gcsUri=model
  )
  deploy_request = messages.GoogleCloudAiplatformV1beta1DeployRequest()
  deploy_request.customModel = custom_model
  if machine_type:
    machine_spec = messages.GoogleCloudAiplatformV1beta1MachineSpec(
        machineType=machine_type,
        acceleratorType=accelerator_type,
        acceleratorCount=accelerator_count,
    )
    dedicated_resources = messages.GoogleCloudAiplatformV1beta1DedicatedResources(
        machineSpec=machine_spec,
        minReplicaCount=1,
    )
    deploy_request.deployConfig = (
        messages.GoogleCloudAiplatformV1beta1DeployRequestDeployConfig(
            dedicatedResources=dedicated_resources,
        )
    )
  return projects_locations_service.Deploy(
      messages.AiplatformProjectsLocationsDeployRequest(
          destination=f'projects/{project}/locations/{location}',
          googleCloudAiplatformV1beta1DeployRequest=deploy_request,
      )
  )
class ModelGardenClient(object):
  """Client used for interacting with Model Garden APIs."""

  def __init__(self, version=constants.BETA_VERSION):
    # The version selects which generated API surface (e.g. v1beta1) backs
    # the client and its message module.
    client = apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version],
    )
    self._messages = client.MESSAGES_MODULE
    self._publishers_models_service = client.publishers_models
    self._projects_locations_service = client.projects_locations

  def GetPublisherModel(
      self,
      model_name,
      is_hugging_face_model=False,
      include_equivalent_model_garden_model_deployment_configs=True,
      hugging_face_token=None,
  ):
    """Get a publisher model.

    Args:
      model_name: The name of the model to get. The format should be
        publishers/{publisher}/models/{model}
      is_hugging_face_model: Whether the model is a hugging face model.
      include_equivalent_model_garden_model_deployment_configs: Whether to
        include equivalent Model Garden model deployment configs for Hugging
        Face models.
      hugging_face_token: The Hugging Face access token to access the model
        artifacts for gated models unverified by Model Garden.

    Returns:
      A publisher model.
    """
    request = self._messages.AiplatformPublishersModelsGetRequest(
        name=model_name,
        isHuggingFaceModel=is_hugging_face_model,
        includeEquivalentModelGardenModelDeploymentConfigs=include_equivalent_model_garden_model_deployment_configs,
        huggingFaceToken=hugging_face_token,
    )
    return self._publishers_models_service.Get(request)

  def Deploy(
      self,
      project,
      location,
      model,
      accept_eula,
      accelerator_type,
      accelerator_count,
      machine_type,
      endpoint_display_name,
      hugging_face_access_token,
      spot,
      reservation_affinity,
      use_dedicated_endpoint,
      enable_fast_tryout,
      container_image_uri=None,
      container_command=None,
      container_args=None,
      container_env_vars=None,
      container_ports=None,
      container_grpc_ports=None,
      container_predict_route=None,
      container_health_route=None,
      container_deployment_timeout_seconds=None,
      container_shared_memory_size_mb=None,
      container_startup_probe_exec=None,
      container_startup_probe_period_seconds=None,
      container_startup_probe_timeout_seconds=None,
      container_health_probe_exec=None,
      container_health_probe_period_seconds=None,
      container_health_probe_timeout_seconds=None,
  ):
    """Deploy an open weight model.

    The model source is selected from the `model` value itself: a Cloud
    Storage URI (custom weights), a Hugging Face model ID (`org/model`), or a
    Model Garden publisher model name.

    Args:
      project: The project to deploy the model to.
      location: The location to deploy the model to.
      model: The name of the model to deploy or its gcs uri for custom weights.
      accept_eula: Whether to accept the end-user license agreement.
      accelerator_type: The type of accelerator to use.
      accelerator_count: The number of accelerators to use.
      machine_type: The type of machine to use.
      endpoint_display_name: The display name of the endpoint.
      hugging_face_access_token: The Hugging Face access token.
      spot: Whether to deploy the model on Spot VMs.
      reservation_affinity: The reservation affinity to use.
      use_dedicated_endpoint: Whether to use a dedicated endpoint.
      enable_fast_tryout: Whether to enable fast tryout.
      container_image_uri: URI of the serving container image in Artifact
        Registry or Container Registry; the remaining container_* arguments
        are only applied when this is set. See the Vertex AI custom container
        requirements documentation for details.
      container_command: Entrypoint override for the container, as an array of
        executable plus arguments (Docker `ENTRYPOINT` "exec" form).
      container_args: Argument overrides for the container (Docker `CMD`
        "default parameters" form).
      container_env_vars: Dict of environment variables to set in the
        container; later entries may reference earlier ones.
      container_ports: HTTP ports to expose from the container. Vertex AI
        sends predictions and health checks to the first port listed.
      container_grpc_ports: gRPC ports to expose from the container; if
        unset, gRPC requests to the container are disabled.
      container_predict_route: HTTP path on the container that receives
        prediction requests.
      container_health_route: HTTP path on the container that receives
        health checks.
      container_deployment_timeout_seconds (int): Deployment timeout in
        seconds.
      container_shared_memory_size_mb (int): The amount of the VM memory to
        reserve as the shared memory for the model in megabytes.
      container_startup_probe_exec (Sequence[str]): Exec command run by the
        startup probe, e.g. ["cat", "/tmp/healthy"].
      container_startup_probe_period_seconds (int): How often (in seconds) to
        perform the startup probe.
      container_startup_probe_timeout_seconds (int): Number of seconds after
        which the startup probe times out.
      container_health_probe_exec (Sequence[str]): Exec command run by the
        health probe.
      container_health_probe_period_seconds (int): How often (in seconds) to
        perform the health probe.
      container_health_probe_timeout_seconds (int): Number of seconds after
        which the health probe times out.

    Returns:
      The deploy long-running operation.
    """
    # Only populated when a custom serving container is requested; every
    # optional container_* flag mutates this spec in place.
    container_spec = None
    if container_image_uri:
      container_spec = (
          self._messages.GoogleCloudAiplatformV1beta1ModelContainerSpec(
              healthRoute=container_health_route,
              imageUri=container_image_uri,
              predictRoute=container_predict_route,
          )
      )
      if container_command:
        container_spec.command = container_command
      if container_args:
        container_spec.args = container_args
      if container_env_vars:
        container_spec.env = [
            self._messages.GoogleCloudAiplatformV1beta1EnvVar(
                name=k, value=container_env_vars[k]
            )
            for k in container_env_vars
        ]
      if container_ports:
        container_spec.ports = [
            self._messages.GoogleCloudAiplatformV1beta1Port(containerPort=port)
            for port in container_ports
        ]
      if container_grpc_ports:
        container_spec.grpcPorts = [
            self._messages.GoogleCloudAiplatformV1beta1Port(containerPort=port)
            for port in container_grpc_ports
        ]
      if container_deployment_timeout_seconds:
        # The API expects a duration string, e.g. "600s".
        container_spec.deploymentTimeout = (
            str(container_deployment_timeout_seconds) + 's'
        )
      if container_shared_memory_size_mb:
        container_spec.sharedMemorySizeMb = container_shared_memory_size_mb
      if (
          container_startup_probe_exec
          or container_startup_probe_period_seconds
          or container_startup_probe_timeout_seconds
      ):
        startup_probe_exec = None
        if container_startup_probe_exec:
          startup_probe_exec = (
              self._messages.GoogleCloudAiplatformV1beta1ProbeExecAction(
                  command=container_startup_probe_exec
              )
          )
        container_spec.startupProbe = (
            self._messages.GoogleCloudAiplatformV1beta1Probe(
                exec_=startup_probe_exec,
                periodSeconds=container_startup_probe_period_seconds,
                timeoutSeconds=container_startup_probe_timeout_seconds,
            )
        )
      if (
          container_health_probe_exec
          or container_health_probe_period_seconds
          or container_health_probe_timeout_seconds
      ):
        health_probe_exec = None
        if container_health_probe_exec:
          health_probe_exec = (
              self._messages.GoogleCloudAiplatformV1beta1ProbeExecAction(
                  command=container_health_probe_exec
              )
          )
        container_spec.healthProbe = (
            self._messages.GoogleCloudAiplatformV1beta1Probe(
                exec_=health_probe_exec,
                periodSeconds=container_health_probe_period_seconds,
                timeoutSeconds=container_health_probe_timeout_seconds,
            )
        )
    if IsCustomWeightsModel(model):
      # NOTE(review): the custom-weights path returns immediately, so the
      # container_spec built above and the eula/endpoint/spot settings below
      # are not forwarded for gs:// models -- confirm this is intended.
      return DeployCustomWeightsModel(
          self._messages,
          self._projects_locations_service,
          model,
          machine_type,
          accelerator_type,
          accelerator_count,
          project,
          location,
      )
    elif IsHuggingFaceModel(model):
      deploy_request = self._messages.GoogleCloudAiplatformV1beta1DeployRequest(
          huggingFaceModelId=model
      )
    else:
      # Anything that is neither a gs:// URI nor an org/model ID is treated
      # as a Model Garden publisher model name.
      deploy_request = self._messages.GoogleCloudAiplatformV1beta1DeployRequest(
          publisherModelName=model
      )
    deploy_request.modelConfig = (
        self._messages.GoogleCloudAiplatformV1beta1DeployRequestModelConfig(
            huggingFaceAccessToken=hugging_face_access_token,
            acceptEula=accept_eula,
            containerSpec=container_spec,
        )
    )
    deploy_request.endpointConfig = (
        self._messages.GoogleCloudAiplatformV1beta1DeployRequestEndpointConfig(
            endpointDisplayName=endpoint_display_name,
            dedicatedEndpointEnabled=use_dedicated_endpoint,
        )
    )
    deploy_request.deployConfig = self._messages.GoogleCloudAiplatformV1beta1DeployRequestDeployConfig(
        dedicatedResources=self._messages.GoogleCloudAiplatformV1beta1DedicatedResources(
            machineSpec=self._messages.GoogleCloudAiplatformV1beta1MachineSpec(
                machineType=machine_type,
                acceleratorType=accelerator_type,
                acceleratorCount=accelerator_count,
                reservationAffinity=flags.ParseReservationAffinityFlag(
                    reservation_affinity, constants.BETA_VERSION
                ),
            ),
            minReplicaCount=1,
            spot=spot,
        ),
        fastTryoutEnabled=enable_fast_tryout,
    )
    request = self._messages.AiplatformProjectsLocationsDeployRequest(
        destination=f'projects/{project}/locations/{location}',
        googleCloudAiplatformV1beta1DeployRequest=deploy_request,
    )
    return self._projects_locations_service.Deploy(request)

  def ListPublisherModels(
      self,
      limit=None,
      batch_size=100,
      list_hf_models=False,
      model_filter=None,
  ):
    """List publisher models in Model Garden.

    Args:
      limit: The maximum number of items to list. None if all available records
        should be yielded.
      batch_size: The number of items to list per page.
      list_hf_models: Whether to only list Hugging Face models.
      model_filter: The filter on model name to apply on server-side.

    Returns:
      The list of publisher models in Model Garden.
    """
    # Default to native Model Garden models; for Hugging Face listings only
    # show wildcard models whose deployment config has been verified.
    filter_str = _NATIVE_MODEL_FILTER
    if list_hf_models:
      filter_str = ' AND '.join(
          [_HF_WILDCARD_FILTER, _VERIFIED_DEPLOYMENT_FILTER]
      )
    if model_filter:
      # Case-insensitive substring match on either the model user ID or the
      # display name, evaluated server-side.
      filter_str = (
          f'{filter_str} AND (model_user_id=~"(?i).*{model_filter}.*" OR'
          f' display_name=~"(?i).*{model_filter}.*")'
      )
    return list_pager.YieldFromList(
        self._publishers_models_service,
        self._messages.AiplatformPublishersModelsListRequest(
            parent='publishers/*',
            listAllVersions=True,
            filter=filter_str,
        ),
        field='publisherModels',
        batch_size_attribute='pageSize',
        batch_size=batch_size,
        limit=limit,
    )

View File

@@ -0,0 +1,528 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with AI Platform model monitoring jobs API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import copy
from apitools.base.py import encoding
from apitools.base.py import extra_types
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.ai import util as api_util
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import messages as messages_util
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.ai import model_monitoring_jobs_util
from googlecloudsdk.command_lib.ai import validation as common_validation
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml
import six
def _ParseEndpoint(endpoint_id, region_ref):
  """Builds an endpoint resource reference from an endpoint ID and region."""
  location_id = region_ref.AsDict()['locationsId']
  return resources.REGISTRY.Parse(
      endpoint_id,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'locationsId': location_id,
      },
      collection='aiplatform.projects.locations.endpoints')
def _ParseDataset(dataset_id, region_ref):
  """Builds a dataset resource reference from a dataset ID and region."""
  location_id = region_ref.AsDict()['locationsId']
  return resources.REGISTRY.Parse(
      dataset_id,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'locationsId': location_id,
      },
      collection='aiplatform.projects.locations.datasets')
class ModelMonitoringJobsClient(object):
"""High-level client for the AI Platform model deployment monitoring jobs surface."""
def __init__(self, client=None, messages=None, version=None):
self.client = client or apis.GetClientInstance(
constants.AI_PLATFORM_API_NAME,
constants.AI_PLATFORM_API_VERSION[version])
self.messages = messages or self.client.MESSAGES_MODULE
self._service = self.client.projects_locations_modelDeploymentMonitoringJobs
self._version = version
  def _ConstructDriftThresholds(self, feature_thresholds,
                                feature_attribution_thresholds):
    """Construct drift thresholds from user input.

    Args:
      feature_thresholds: Dict or None, key: feature_name, value: thresholds.
      feature_attribution_thresholds: Dict or None, key:feature_name, value:
        attribution score thresholds.

    Returns:
      PredictionDriftDetectionConfig
    """
    prediction_drift_detection = api_util.GetMessage(
        'ModelMonitoringObjectiveConfigPredictionDriftDetectionConfig',
        self._version)()
    additional_properties = []
    attribution_additional_properties = []
    if feature_thresholds:
      for key, value in feature_thresholds.items():
        # A missing/empty threshold value falls back to the 0.3 default.
        threshold = 0.3 if not value else float(value)
        additional_properties.append(prediction_drift_detection
                                     .DriftThresholdsValue().AdditionalProperty(
                                         key=key,
                                         value=api_util.GetMessage(
                                             'ThresholdConfig',
                                             self._version)(value=threshold)))
      prediction_drift_detection.driftThresholds = prediction_drift_detection.DriftThresholdsValue(
          additionalProperties=additional_properties)
    if feature_attribution_thresholds:
      for key, value in feature_attribution_thresholds.items():
        # Same 0.3 default applies to attribution-score thresholds.
        threshold = 0.3 if not value else float(value)
        attribution_additional_properties.append(
            prediction_drift_detection.AttributionScoreDriftThresholdsValue(
            ).AdditionalProperty(
                key=key,
                value=api_util.GetMessage('ThresholdConfig',
                                          self._version)(value=threshold)))
      prediction_drift_detection.attributionScoreDriftThresholds = prediction_drift_detection.AttributionScoreDriftThresholdsValue(
          additionalProperties=attribution_additional_properties)
    return prediction_drift_detection
  def _ConstructSkewThresholds(self, feature_thresholds,
                               feature_attribution_thresholds):
    """Construct skew thresholds from user input.

    Args:
      feature_thresholds: Dict or None, key: feature_name, value: thresholds.
      feature_attribution_thresholds: Dict or None, key:feature_name, value:
        attribution score thresholds.

    Returns:
      TrainingPredictionSkewDetectionConfig
    """
    training_prediction_skew_detection = api_util.GetMessage(
        'ModelMonitoringObjectiveConfigTrainingPredictionSkewDetectionConfig',
        self._version)()
    additional_properties = []
    attribution_additional_properties = []
    if feature_thresholds:
      for key, value in feature_thresholds.items():
        # A missing/empty threshold value falls back to the 0.3 default.
        threshold = 0.3 if not value else float(value)
        additional_properties.append(training_prediction_skew_detection
                                     .SkewThresholdsValue().AdditionalProperty(
                                         key=key,
                                         value=api_util.GetMessage(
                                             'ThresholdConfig',
                                             self._version)(value=threshold)))
      training_prediction_skew_detection.skewThresholds = training_prediction_skew_detection.SkewThresholdsValue(
          additionalProperties=additional_properties)
    if feature_attribution_thresholds:
      for key, value in feature_attribution_thresholds.items():
        # Same 0.3 default applies to attribution-score thresholds.
        threshold = 0.3 if not value else float(value)
        attribution_additional_properties.append(
            training_prediction_skew_detection
            .AttributionScoreSkewThresholdsValue().AdditionalProperty(
                key=key,
                value=api_util.GetMessage('ThresholdConfig',
                                          self._version)(value=threshold)))
      training_prediction_skew_detection.attributionScoreSkewThresholds = training_prediction_skew_detection.AttributionScoreSkewThresholdsValue(
          additionalProperties=attribution_additional_properties)
    return training_prediction_skew_detection
  def _ConstructObjectiveConfigForUpdate(self, existing_monitoring_job,
                                         feature_thresholds,
                                         feature_attribution_thresholds):
    """Construct monitoring objective config.

    Update the feature thresholds for skew/drift detection to all the existing
    deployed models under the job.

    Args:
      existing_monitoring_job: Existing monitoring job.
      feature_thresholds: Dict or None, key: feature_name, value: thresholds.
      feature_attribution_thresholds: Dict or None, key: feature_name, value:
        attribution score thresholds.

    Returns:
      A list of model monitoring objective config.
    """
    # Build the threshold configs once; they are applied below to every
    # deployed model's existing objective config.
    prediction_drift_detection = self._ConstructDriftThresholds(
        feature_thresholds, feature_attribution_thresholds)
    training_prediction_skew_detection = self._ConstructSkewThresholds(
        feature_thresholds, feature_attribution_thresholds)
    objective_configs = []
    for objective_config in existing_monitoring_job.modelDeploymentMonitoringObjectiveConfigs:
      # Only overwrite thresholds for detection types the job already has
      # configured; this update path does not add new detection types.
      if objective_config.objectiveConfig.trainingPredictionSkewDetectionConfig:
        if training_prediction_skew_detection.skewThresholds:
          objective_config.objectiveConfig.trainingPredictionSkewDetectionConfig.skewThresholds = training_prediction_skew_detection.skewThresholds
        if training_prediction_skew_detection.attributionScoreSkewThresholds:
          objective_config.objectiveConfig.trainingPredictionSkewDetectionConfig.attributionScoreSkewThresholds = training_prediction_skew_detection.attributionScoreSkewThresholds
      if objective_config.objectiveConfig.predictionDriftDetectionConfig:
        if prediction_drift_detection.driftThresholds:
          objective_config.objectiveConfig.predictionDriftDetectionConfig.driftThresholds = prediction_drift_detection.driftThresholds
        if prediction_drift_detection.attributionScoreDriftThresholds:
          objective_config.objectiveConfig.predictionDriftDetectionConfig.attributionScoreDriftThresholds = prediction_drift_detection.attributionScoreDriftThresholds
      if training_prediction_skew_detection.attributionScoreSkewThresholds or prediction_drift_detection.attributionScoreDriftThresholds:
        # Attribution-score thresholds require feature attributions, so the
        # explanation config is enabled whenever they are present.
        objective_config.objectiveConfig.explanationConfig = api_util.GetMessage(
            'ModelMonitoringObjectiveConfigExplanationConfig', self._version)(
                enableFeatureAttributes=True)
      objective_configs.append(objective_config)
    return objective_configs
def _ConstructObjectiveConfigForCreate(self, location_ref, endpoint_name,
                                       feature_thresholds,
                                       feature_attribution_thresholds,
                                       dataset, bigquery_uri, data_format,
                                       gcs_uris, target_field,
                                       training_sampling_rate):
  """Constructs the monitoring objective configs for a job being created.

  Applies the same feature thresholds for skew or drift detection to every
  model currently deployed under the endpoint. If any training data source
  (managed dataset, BigQuery table or Cloud Storage uris) is supplied,
  training-prediction skew detection is configured; otherwise prediction
  drift detection is configured.

  Args:
    location_ref: Location reference.
    endpoint_name: Endpoint resource name.
    feature_thresholds: Dict or None, key: feature_name, value: thresholds.
    feature_attribution_thresholds: Dict or None, key: feature_name, value:
      attribution score thresholds.
    dataset: Vertex AI Dataset Id.
    bigquery_uri: The BigQuery table of the unmanaged Dataset used to train
      this Model.
    data_format: Google Cloud Storage format, supported format: csv,
      tf-record.
    gcs_uris: The Google Cloud Storage uri of the unmanaged Dataset used to
      train this Model.
    target_field: The target field name the model is to predict.
    training_sampling_rate: Training Dataset sampling rate.

  Returns:
    A list of model monitoring objective configs, one per deployed model.

  Raises:
    errors.ArgumentError: If skew detection is requested without a target
      field, or the Cloud Storage source arguments are incomplete.
  """
  objective_config_template = api_util.GetMessage(
      'ModelDeploymentMonitoringObjectiveConfig', self._version)()
  if feature_thresholds or feature_attribution_thresholds:
    if dataset or bigquery_uri or gcs_uris or data_format:
      # A training data source was provided, so configure
      # training-prediction skew detection.
      training_dataset = api_util.GetMessage(
          'ModelMonitoringObjectiveConfigTrainingDataset', self._version)()
      if target_field is None:
        raise errors.ArgumentError(
            "Target field must be provided if you'd like to do training-prediction skew detection."
        )
      training_dataset.targetField = target_field
      training_dataset.loggingSamplingStrategy = api_util.GetMessage(
          'SamplingStrategy', self._version)(
              randomSampleConfig=api_util.GetMessage(
                  'SamplingStrategyRandomSampleConfig', self._version)(
                      sampleRate=training_sampling_rate))
      # Exactly one training data source is attached, preferring a managed
      # dataset, then a BigQuery table, then Cloud Storage uris.
      if dataset:
        training_dataset.dataset = _ParseDataset(dataset,
                                                 location_ref).RelativeName()
      elif bigquery_uri:
        training_dataset.bigquerySource = api_util.GetMessage(
            'BigQuerySource', self._version)(
                inputUri=bigquery_uri)
      elif gcs_uris or data_format:
        # A Cloud Storage source requires both the uris and the format.
        if gcs_uris is None:
          raise errors.ArgumentError(
              'Data format is defined but no Google Cloud Storage uris are provided. Please use --gcs-uris to provide training datasets.'
          )
        if data_format is None:
          raise errors.ArgumentError(
              'No Data format is defined for Google Cloud Storage training dataset. Please use --data-format to define the Data format.'
          )
        training_dataset.dataFormat = data_format
        training_dataset.gcsSource = api_util.GetMessage(
            'GcsSource', self._version)(
                uris=gcs_uris)
      training_prediction_skew_detection = self._ConstructSkewThresholds(
          feature_thresholds, feature_attribution_thresholds)
      objective_config_template.objectiveConfig = api_util.GetMessage(
          'ModelMonitoringObjectiveConfig', self._version
      )(trainingDataset=training_dataset,
        trainingPredictionSkewDetectionConfig=training_prediction_skew_detection
       )
    else:
      # No training data source: fall back to prediction drift detection.
      prediction_drift_detection = self._ConstructDriftThresholds(
          feature_thresholds, feature_attribution_thresholds)
      objective_config_template.objectiveConfig = api_util.GetMessage(
          'ModelMonitoringObjectiveConfig', self._version)(
              predictionDriftDetectionConfig=prediction_drift_detection)
    if feature_attribution_thresholds:
      # Attribution-score monitoring requires feature attributions
      # (explanations) to be enabled on the objective.
      objective_config_template.objectiveConfig.explanationConfig = api_util.GetMessage(
          'ModelMonitoringObjectiveConfigExplanationConfig', self._version)(
              enableFeatureAttributes=True)
  # Fan the template out to every model deployed under the endpoint.
  get_endpoint_req = self.messages.AiplatformProjectsLocationsEndpointsGetRequest(
      name=endpoint_name)
  endpoint = self.client.projects_locations_endpoints.Get(get_endpoint_req)
  objective_configs = []
  for deployed_model in endpoint.deployedModels:
    objective_config = copy.deepcopy(objective_config_template)
    objective_config.deployedModelId = deployed_model.id
    objective_configs.append(objective_config)
  return objective_configs
def _ParseCreateLabels(self, args):
  """Builds the labels message for a job being created from CLI args."""
  labels_value = api_util.GetMessage(
      'ModelDeploymentMonitoringJob', self._version)().LabelsValue
  return labels_util.ParseCreateArgs(args, labels_value)
def _ParseUpdateLabels(self, model_monitoring_job_ref, args):
  """Computes the labels diff for an existing monitoring job."""
  labels_value = api_util.GetMessage(
      'ModelDeploymentMonitoringJob', self._version)().LabelsValue

  def GetLabels():
    # Fetched lazily: the job is only read when a diff actually needs it.
    return self.Get(model_monitoring_job_ref).labels

  return labels_util.ProcessUpdateArgsLazy(args, labels_value, GetLabels)
def Create(self, location_ref, args):
  """Creates a model deployment monitoring job.

  The job spec is either loaded wholesale from
  --monitoring-config-from-file or assembled from the individual
  threshold/dataset flags; alerting, sampling and scheduling settings from
  args are then applied on top.

  Args:
    location_ref: Resource reference of the location to create the job in.
    args: Parsed argparse namespace carrying the job creation flags.

  Returns:
    The created model deployment monitoring job resource.
  """
  endpoint_ref = _ParseEndpoint(args.endpoint, location_ref)
  job_spec = api_util.GetMessage('ModelDeploymentMonitoringJob',
                                 self._version)()
  kms_key_name = common_validation.GetAndValidateKmsKey(args)
  if kms_key_name is not None:
    job_spec.encryptionSpec = api_util.GetMessage('EncryptionSpec',
                                                  self._version)(
                                                      kmsKeyName=kms_key_name)
  if args.monitoring_config_from_file:
    # A full job spec from file replaces the flag-built objective configs.
    # NOTE(review): when the file parses to a non-empty message, the
    # reassignment below also discards the encryptionSpec set above --
    # confirm this is intended.
    data = yaml.load_path(args.monitoring_config_from_file)
    if data:
      job_spec = messages_util.DictToMessageWithErrorCheck(
          data,
          api_util.GetMessage('ModelDeploymentMonitoringJob', self._version))
  else:
    job_spec.modelDeploymentMonitoringObjectiveConfigs = self._ConstructObjectiveConfigForCreate(
        location_ref, endpoint_ref.RelativeName(), args.feature_thresholds,
        args.feature_attribution_thresholds, args.dataset, args.bigquery_uri,
        args.data_format, args.gcs_uris, args.target_field,
        args.training_sampling_rate)
  job_spec.endpoint = endpoint_ref.RelativeName()
  job_spec.displayName = args.display_name
  job_spec.labels = self._ParseCreateLabels(args)
  # Anomaly logging to Cloud Logging is off unless explicitly requested.
  enable_anomaly_cloud_logging = False if args.anomaly_cloud_logging is None else args.anomaly_cloud_logging
  job_spec.modelMonitoringAlertConfig = api_util.GetMessage(
      'ModelMonitoringAlertConfig', self._version)(
          enableLogging=enable_anomaly_cloud_logging,
          emailAlertConfig=api_util.GetMessage(
              'ModelMonitoringAlertConfigEmailAlertConfig',
              self._version)(userEmails=args.emails),
          notificationChannels=args.notification_channels)
  job_spec.loggingSamplingStrategy = api_util.GetMessage(
      'SamplingStrategy', self._version)(
          randomSampleConfig=api_util.GetMessage(
              'SamplingStrategyRandomSampleConfig', self._version)(
                  sampleRate=args.prediction_sampling_rate))
  # The flag is in hours; the API expects a duration string in seconds.
  job_spec.modelDeploymentMonitoringScheduleConfig = api_util.GetMessage(
      'ModelDeploymentMonitoringScheduleConfig', self._version)(
          monitorInterval='{}s'.format(
              six.text_type(3600 * int(args.monitoring_frequency))))
  if args.predict_instance_schema:
    job_spec.predictInstanceSchemaUri = args.predict_instance_schema
  if args.analysis_instance_schema:
    job_spec.analysisInstanceSchemaUri = args.analysis_instance_schema
  if args.log_ttl:
    # The flag is in days; the API expects a duration string in seconds.
    job_spec.logTtl = '{}s'.format(six.text_type(86400 * int(args.log_ttl)))
  if args.sample_predict_request:
    instance_json = model_monitoring_jobs_util.ReadInstanceFromArgs(
        args.sample_predict_request)
    job_spec.samplePredictInstance = encoding.PyValueToMessage(
        extra_types.JsonValue, instance_json)
  # The request's job field name is versioned (v1beta1 vs v1).
  if self._version == constants.BETA_VERSION:
    return self._service.Create(
        self.messages.
        AiplatformProjectsLocationsModelDeploymentMonitoringJobsCreateRequest(
            parent=location_ref.RelativeName(),
            googleCloudAiplatformV1beta1ModelDeploymentMonitoringJob=job_spec
        ))
  else:
    return self._service.Create(
        self.messages.
        AiplatformProjectsLocationsModelDeploymentMonitoringJobsCreateRequest(
            parent=location_ref.RelativeName(),
            googleCloudAiplatformV1ModelDeploymentMonitoringJob=job_spec))
def Patch(self, model_monitoring_job_ref, args):
  """Update a model deployment monitoring job.

  Only fields whose corresponding flags were given are modified; each
  change also appends the matching proto field path to the request's
  update mask.

  Args:
    model_monitoring_job_ref: Resource reference of the job to update.
    args: Parsed argparse namespace carrying the update flags.

  Returns:
    The result of the Patch call (a long-running operation).

  Raises:
    errors.NoFieldsSpecifiedError: If no supported update flag was given.
  """
  model_monitoring_job_to_update = api_util.GetMessage(
      'ModelDeploymentMonitoringJob', self._version)()
  update_mask = []
  job_spec = api_util.GetMessage('ModelDeploymentMonitoringJob',
                                 self._version)()
  if args.monitoring_config_from_file:
    # Objective configs come verbatim from the provided config file.
    data = yaml.load_path(args.monitoring_config_from_file)
    if data:
      job_spec = messages_util.DictToMessageWithErrorCheck(
          data,
          api_util.GetMessage('ModelDeploymentMonitoringJob', self._version))
    model_monitoring_job_to_update.modelDeploymentMonitoringObjectiveConfigs = job_spec.modelDeploymentMonitoringObjectiveConfigs
    update_mask.append('model_deployment_monitoring_objective_configs')
  if args.feature_thresholds or args.feature_attribution_thresholds:
    # Merge the new thresholds into the job's existing objective configs.
    get_monitoring_job_req = self.messages.AiplatformProjectsLocationsModelDeploymentMonitoringJobsGetRequest(
        name=model_monitoring_job_ref.RelativeName())
    model_monitoring_job = self._service.Get(get_monitoring_job_req)
    model_monitoring_job_to_update.modelDeploymentMonitoringObjectiveConfigs = self._ConstructObjectiveConfigForUpdate(
        model_monitoring_job, args.feature_thresholds,
        args.feature_attribution_thresholds)
    update_mask.append('model_deployment_monitoring_objective_configs')
  if args.display_name:
    model_monitoring_job_to_update.displayName = args.display_name
    update_mask.append('display_name')
  if args.emails:
    model_monitoring_job_to_update.modelMonitoringAlertConfig = (
        api_util.GetMessage('ModelMonitoringAlertConfig', self._version)(
            emailAlertConfig=api_util.GetMessage(
                'ModelMonitoringAlertConfigEmailAlertConfig', self._version
            )(userEmails=args.emails)
        )
    )
    update_mask.append('model_monitoring_alert_config.email_alert_config')
  if args.anomaly_cloud_logging is not None:
    # Reuse the alert config created for --emails when present; otherwise
    # start a fresh one so the earlier assignment is not clobbered.
    if args.emails:
      model_monitoring_job_to_update.modelMonitoringAlertConfig.enableLogging = (
          args.anomaly_cloud_logging
      )
    else:
      model_monitoring_job_to_update.modelMonitoringAlertConfig = (
          api_util.GetMessage('ModelMonitoringAlertConfig', self._version)(
              enableLogging=args.anomaly_cloud_logging
          )
      )
    update_mask.append('model_monitoring_alert_config.enable_logging')
  if args.notification_channels:
    # Same pattern: only create a new alert config if neither of the two
    # branches above already did.
    if args.emails or args.anomaly_cloud_logging is not None:
      model_monitoring_job_to_update.modelMonitoringAlertConfig.notificationChannels = (
          args.notification_channels
      )
    else:
      model_monitoring_job_to_update.modelMonitoringAlertConfig = (
          api_util.GetMessage('ModelMonitoringAlertConfig', self._version)(
              notificationChannels=args.notification_channels
          )
      )
    update_mask.append('model_monitoring_alert_config.notification_channels')
  # Prediction request sampling rate.
  if args.prediction_sampling_rate:
    model_monitoring_job_to_update.loggingSamplingStrategy = api_util.GetMessage(
        'SamplingStrategy', self._version)(
            randomSampleConfig=api_util.GetMessage(
                'SamplingStrategyRandomSampleConfig', self._version)(
                    sampleRate=args.prediction_sampling_rate))
    update_mask.append('logging_sampling_strategy')
  # Schedule: the flag is in hours, the API expects seconds.
  if args.monitoring_frequency:
    model_monitoring_job_to_update.modelDeploymentMonitoringScheduleConfig = api_util.GetMessage(
        'ModelDeploymentMonitoringScheduleConfig', self._version)(
            monitorInterval='{}s'.format(
                six.text_type(3600 * int(args.monitoring_frequency))))
    update_mask.append('model_deployment_monitoring_schedule_config')
  if args.analysis_instance_schema:
    model_monitoring_job_to_update.analysisInstanceSchemaUri = args.analysis_instance_schema
    update_mask.append('analysis_instance_schema_uri')
  if args.log_ttl:
    # The flag is in days; the API expects a duration string in seconds.
    model_monitoring_job_to_update.logTtl = '{}s'.format(
        six.text_type(86400 * int(args.log_ttl)))
    update_mask.append('log_ttl')
  labels_update = self._ParseUpdateLabels(model_monitoring_job_ref, args)
  if labels_update.needs_update:
    model_monitoring_job_to_update.labels = labels_update.labels
    update_mask.append('labels')
  if not update_mask:
    raise errors.NoFieldsSpecifiedError('No updates requested.')
  # The request's job field name is versioned (v1beta1 vs v1).
  if self._version == constants.BETA_VERSION:
    req = self.messages.AiplatformProjectsLocationsModelDeploymentMonitoringJobsPatchRequest(
        name=model_monitoring_job_ref.RelativeName(),
        googleCloudAiplatformV1beta1ModelDeploymentMonitoringJob=model_monitoring_job_to_update,
        updateMask=','.join(update_mask))
  else:
    req = self.messages.AiplatformProjectsLocationsModelDeploymentMonitoringJobsPatchRequest(
        name=model_monitoring_job_ref.RelativeName(),
        googleCloudAiplatformV1ModelDeploymentMonitoringJob=model_monitoring_job_to_update,
        updateMask=','.join(update_mask))
  return self._service.Patch(req)
def Get(self, model_monitoring_job_ref):
  """Fetches a model deployment monitoring job by resource reference."""
  job_name = model_monitoring_job_ref.RelativeName()
  get_req = (
      self.messages
      .AiplatformProjectsLocationsModelDeploymentMonitoringJobsGetRequest(
          name=job_name))
  return self._service.Get(get_req)
def List(self, limit=None, region_ref=None):
  """Yields monitoring jobs in the given region, up to limit."""
  list_req = (
      self.messages
      .AiplatformProjectsLocationsModelDeploymentMonitoringJobsListRequest(
          parent=region_ref.RelativeName()))
  return list_pager.YieldFromList(
      self._service,
      list_req,
      field='modelDeploymentMonitoringJobs',
      batch_size_attribute='pageSize',
      limit=limit)
def Delete(self, model_monitoring_job_ref):
  """Deletes the given model deployment monitoring job."""
  job_name = model_monitoring_job_ref.RelativeName()
  delete_req = (
      self.messages
      .AiplatformProjectsLocationsModelDeploymentMonitoringJobsDeleteRequest(
          name=job_name))
  return self._service.Delete(delete_req)
def Pause(self, model_monitoring_job_ref):
  """Pauses the given model deployment monitoring job."""
  job_name = model_monitoring_job_ref.RelativeName()
  pause_req = (
      self.messages
      .AiplatformProjectsLocationsModelDeploymentMonitoringJobsPauseRequest(
          name=job_name))
  return self._service.Pause(pause_req)
def Resume(self, model_monitoring_job_ref):
  """Resumes the given (paused) model deployment monitoring job."""
  job_name = model_monitoring_job_ref.RelativeName()
  resume_req = (
      self.messages
      .AiplatformProjectsLocationsModelDeploymentMonitoringJobsResumeRequest(
          name=job_name))
  return self._service.Resume(resume_req)

View File

@@ -0,0 +1,895 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform models API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
class ModelsClient(object):
"""High-level client for the AI Platform models surface.
Attributes:
client: An instance of the given client, or the API client aiplatform of
Beta version.
messages: The messages module for the given client, or the API client
aiplatform of Beta version.
"""
def __init__(self, client=None, messages=None):
  """Initializes the models client.

  Args:
    client: Optional pre-built API client; defaults to the Beta aiplatform
      client.
    messages: Optional messages module; defaults to the client's own.
  """
  if not client:
    client = apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[constants.BETA_VERSION])
  self.client = client
  self.messages = messages or self.client.MESSAGES_MODULE
  self._service = self.client.projects_locations_models
def UploadV1Beta1(
    self,
    region_ref=None,
    display_name=None,
    description=None,
    version_description=None,
    artifact_uri=None,
    container_image_uri=None,
    container_command=None,
    container_args=None,
    container_env_vars=None,
    container_ports=None,
    container_grpc_ports=None,
    container_predict_route=None,
    container_health_route=None,
    container_deployment_timeout_seconds=None,
    container_shared_memory_size_mb=None,
    container_startup_probe_exec=None,
    container_startup_probe_period_seconds=None,
    container_startup_probe_timeout_seconds=None,
    container_health_probe_exec=None,
    container_health_probe_period_seconds=None,
    container_health_probe_timeout_seconds=None,
    explanation_spec=None,
    parent_model=None,
    model_id=None,
    version_aliases=None,
    labels=None,
    base_model_source=None,
):
  """Sends a v1beta1 UploadModel request and returns the resulting LRO.

  Builds a GoogleCloudAiplatformV1beta1Model from the given container,
  explanation and versioning settings, then calls ModelService.UploadModel
  under the given region.

  Args:
    region_ref: The resource reference for the region to upload into.
    display_name: Display name of the Model (up to 128 UTF-8 characters).
    description: Description of the Model.
    version_description: Description of the Model version.
    artifact_uri: Path to the directory with the Model artifact and its
      supporting files.
    container_image_uri: URI of the serving container image in Artifact
      Registry or Container Registry.
    container_command: Entrypoint command for the container (overrides the
      image's `ENTRYPOINT`), in Docker "exec" form.
    container_args: Arguments for the container command (overrides the
      image's `CMD`).
    container_env_vars: Dict of environment variables to set in the
      container.
    container_ports: Ports to expose for HTTP prediction traffic; Vertex AI
      uses the first port listed.
    container_grpc_ports: Ports to expose for gRPC prediction traffic;
      Vertex AI uses the first port listed.
    container_predict_route: HTTP path on the container that receives
      prediction requests.
    container_health_route: HTTP path on the container that receives health
      checks.
    container_deployment_timeout_seconds: Deployment timeout in seconds.
    container_shared_memory_size_mb: VM memory to reserve as shared memory
      for the model, in megabytes.
    container_startup_probe_exec: Exec command for the startup probe, e.g.
      ["cat", "/tmp/healthy"].
    container_startup_probe_period_seconds: Startup probe period (seconds).
    container_startup_probe_timeout_seconds: Startup probe timeout
      (seconds).
    container_health_probe_exec: Exec command for the health probe.
    container_health_probe_period_seconds: Health probe period (seconds).
    container_health_probe_timeout_seconds: Health probe timeout (seconds).
    explanation_spec: Default explanation specification for this Model.
    parent_model: Resource name of the model to upload a new version into.
    model_id: ID to use for the uploaded Model (final component of the
      model resource name).
    version_aliases: User-provided version aliases for referencing the
      model version.
    labels: Dict of user-defined labels to organize the Model.
    base_model_source: A GoogleCloudAiplatformV1beta1ModelBaseModelSource
      indicating the source of the model (Model Garden / Generative AI
      Studio).

  Returns:
    Response from calling upload model with given request arguments.
  """
  msgs = self.messages

  def _MaybeProbe(exec_command, period_seconds, timeout_seconds):
    # Returns a Probe message, or None when no probe setting was supplied.
    if not (exec_command or period_seconds or timeout_seconds):
      return None
    action = None
    if exec_command:
      action = msgs.GoogleCloudAiplatformV1beta1ProbeExecAction(
          command=exec_command)
    return msgs.GoogleCloudAiplatformV1beta1Probe(
        exec_=action,
        periodSeconds=period_seconds,
        timeoutSeconds=timeout_seconds)

  container_spec = msgs.GoogleCloudAiplatformV1beta1ModelContainerSpec(
      healthRoute=container_health_route,
      imageUri=container_image_uri,
      predictRoute=container_predict_route)
  if container_command:
    container_spec.command = container_command
  if container_args:
    container_spec.args = container_args
  if container_env_vars:
    container_spec.env = [
        msgs.GoogleCloudAiplatformV1beta1EnvVar(name=name, value=value)
        for name, value in container_env_vars.items()
    ]
  if container_ports:
    container_spec.ports = [
        msgs.GoogleCloudAiplatformV1beta1Port(containerPort=p)
        for p in container_ports
    ]
  if container_grpc_ports:
    container_spec.grpcPorts = [
        msgs.GoogleCloudAiplatformV1beta1Port(containerPort=p)
        for p in container_grpc_ports
    ]
  if container_deployment_timeout_seconds:
    # The API expects a duration string, e.g. '600s'.
    container_spec.deploymentTimeout = '{}s'.format(
        container_deployment_timeout_seconds)
  if container_shared_memory_size_mb:
    container_spec.sharedMemorySizeMb = container_shared_memory_size_mb

  startup_probe = _MaybeProbe(container_startup_probe_exec,
                              container_startup_probe_period_seconds,
                              container_startup_probe_timeout_seconds)
  if startup_probe is not None:
    container_spec.startupProbe = startup_probe
  health_probe = _MaybeProbe(container_health_probe_exec,
                             container_health_probe_period_seconds,
                             container_health_probe_timeout_seconds)
  if health_probe is not None:
    container_spec.healthProbe = health_probe

  model = msgs.GoogleCloudAiplatformV1beta1Model(
      artifactUri=artifact_uri,
      containerSpec=container_spec,
      description=description,
      versionDescription=version_description,
      displayName=display_name,
      explanationSpec=explanation_spec,
      baseModelSource=base_model_source)
  if version_aliases:
    model.versionAliases = version_aliases
  if labels:
    # Labels are sorted by key for a deterministic request payload.
    model.labels = model.LabelsValue(additionalProperties=[
        model.LabelsValue.AdditionalProperty(key=key, value=value)
        for key, value in sorted(labels.items())
    ])

  upload_request = msgs.GoogleCloudAiplatformV1beta1UploadModelRequest(
      model=model, parentModel=parent_model, modelId=model_id)
  return self._service.Upload(
      msgs.AiplatformProjectsLocationsModelsUploadRequest(
          parent=region_ref.RelativeName(),
          googleCloudAiplatformV1beta1UploadModelRequest=upload_request))
def UploadV1(self,
region_ref=None,
display_name=None,
description=None,
version_description=None,
artifact_uri=None,
container_image_uri=None,
container_command=None,
container_args=None,
container_env_vars=None,
container_ports=None,
container_grpc_ports=None,
container_predict_route=None,
container_health_route=None,
container_deployment_timeout_seconds=None,
container_shared_memory_size_mb=None,
container_startup_probe_exec=None,
container_startup_probe_period_seconds=None,
container_startup_probe_timeout_seconds=None,
container_health_probe_exec=None,
container_health_probe_period_seconds=None,
container_health_probe_timeout_seconds=None,
explanation_spec=None,
parent_model=None,
model_id=None,
version_aliases=None,
labels=None):
"""Constructs, sends an UploadModel request and returns the LRO to be done.
Args:
region_ref: The resource reference for a given region. None if the region
reference is not provided.
display_name: The display name of the Model. The name can be up to 128
characters long and can be consist of any UTF-8 characters.
description: The description of the Model.
version_description: The description of the Model version.
artifact_uri: The path to the directory containing the Model artifact and
any of its supporting files. Not present for AutoML Models.
container_image_uri: Immutable. URI of the Docker image to be used as the
custom container for serving predictions. This URI must identify an
image in Artifact Registry or Container Registry. Learn more about the
[container publishing requirements](https://cloud.google.com/vertex-
ai/docs/predictions/custom-container-requirements#publishing), including
permissions requirements for the Vertex AI Service Agent. The container
image is ingested upon ModelService.UploadModel, stored internally, and
this original path is afterwards not used. To learn about the
requirements for the Docker image itself, see [Custom container
requirements](https://cloud.google.com/vertex-
ai/docs/predictions/custom-container-requirements#). You can use the URI
to one of Vertex AI's [pre-built container images for
prediction](https://cloud.google.com/vertex-ai/docs/predictions/pre-
built-containers) in this field.
container_command: Specifies the command that runs when the container
starts. This overrides the container's [ENTRYPOINT](https://docs.docker.
com/engine/reference/builder/#entrypoint). Specify this field as an
array of executable and arguments, similar to a Docker `ENTRYPOINT`'s
"exec" form, not its "shell" form. If you do not specify this field,
then the container's `ENTRYPOINT` runs, in conjunction with the args
field or the container's
[`CMD`](https://docs.docker.com/engine/reference/builder/#cmd), if
either exists. If this field is not specified and the container does not
have an `ENTRYPOINT`, then refer to the Docker documentation about [how
`CMD` and `ENTRYPOINT`
interact](https://docs.docker.com/engine/reference/builder/#understand-
how-cmd-and-entrypoint-interact). If you specify this field, then you
can also specify the `args` field to provide additional arguments for
this command. However, if you specify this field, then the container's
`CMD` is ignored. See the [Kubernetes documentation about how the
`command` and `args` fields interact with a container's `ENTRYPOINT` and
`CMD`](https://kubernetes.io/docs/tasks/inject-data-application/define-
command-argument-container/#notes). In this field, you can reference
[environment variables set by Vertex
AI](https://cloud.google.com/vertex-ai/docs/predictions/custom-
container-requirements#aip-variables) and environment variables set in
the env field. You cannot reference environment variables set in the
Docker image. In order for environment variables to be expanded,
reference them by using the following syntax: $( VARIABLE_NAME) Note
that this differs from Bash variable expansion, which does not use
parentheses. If a variable cannot be resolved, the reference in the
input string is used unchanged. To avoid variable expansion, you can
escape this syntax with `$$`; for example: $$(VARIABLE_NAME) This field
corresponds to the `command` field of the Kubernetes Containers [v1 core
API](https://kubernetes.io/docs/reference/generated/kubernetes-
api/v1.23/#container-v1-core).
container_args: Specifies arguments for the command that runs when the
container starts. This overrides the container's
[`CMD`](https://docs.docker.com/engine/reference/builder/#cmd). Specify
this field as an array of executable and arguments, similar to a Docker
`CMD`'s "default parameters" form. If you don't specify this field but
do specify the command field, then the command from the `command` field
runs without any additional arguments. See the [Kubernetes documentation
about how the `command` and `args` fields interact with a container's
`ENTRYPOINT` and `CMD`](https://kubernetes.io/docs/tasks/inject-data-
application/define-command-argument-container/#notes). If you don't
specify this field and don't specify the `command` field, then the
container's
[`ENTRYPOINT`](https://docs.docker.com/engine/reference/builder/#cmd)
and `CMD` determine what runs based on their default behavior. See the
Docker documentation about [how `CMD` and `ENTRYPOINT`
interact](https://docs.docker.com/engine/reference/builder/#understand-
how-cmd-and-entrypoint-interact). In this field, you can reference
[environment variables set by Vertex
AI](https://cloud.google.com/vertex-ai/docs/predictions/custom-
container-requirements#aip-variables) and environment variables set in
the env field. You cannot reference environment variables set in the
Docker image. In order for environment variables to be expanded,
reference them by using the following syntax: $( VARIABLE_NAME) Note
that this differs from Bash variable expansion, which does not use
parentheses. If a variable cannot be resolved, the reference in the
input string is used unchanged. To avoid variable expansion, you can
escape this syntax with `$$`; for example: $$(VARIABLE_NAME) This field
corresponds to the `args` field of the Kubernetes Containers [v1 core
API](https://kubernetes.io/docs/reference/generated/kubernetes-
api/v1.23/#container-v1-core)..
container_env_vars: List of environment variables to set in the container.
After the container starts running, code running in the container can
read these environment variables. Additionally, the command and args
fields can reference these variables. Later entries in this list can
also reference earlier entries. For example, the following example sets
the variable `VAR_2` to have the value `foo bar`: ```json [ { "name":
"VAR_1", "value": "foo" }, { "name": "VAR_2", "value": "$(VAR_1) bar" }
] ``` If you switch the order of the variables in the example, then the
expansion does not occur. This field corresponds to the `env` field of
the Kubernetes Containers [v1 core
API](https://kubernetes.io/docs/reference/generated/kubernetes-
api/v1.23/#container-v1-core).
container_ports: List of ports to expose from the container. Vertex AI
sends any http prediction requests that it receives to the first port on
this list. Vertex AI also sends [liveness and health
checks](https://cloud.google.com/vertex-ai/docs/predictions/custom-
container-requirements#liveness) to this port. If you do not specify
this field, it defaults to following value: ```json [ { "containerPort":
8080 } ] ``` Vertex AI does not use ports other than the first one
listed. This field corresponds to the `ports` field of the Kubernetes
Containers [v1 core
API](https://kubernetes.io/docs/reference/generated/kubernetes-
api/v1.23/#container-v1-core).
container_grpc_ports: List of ports to expose from the container. Vertex
AI sends any grpc prediction requests that it receives to the first port
on this list. Vertex AI also sends [liveness and health
checks](https://cloud.google.com/vertex-ai/docs/predictions/custom-
container-requirements#liveness) to this port. If you do not specify
this field, gRPC requests to the container will be disabled. Vertex AI
does not use ports other than the first one listed. This field
corresponds to the `ports` field of the Kubernetes Containers [v1 core
API](https://kubernetes.io/docs/reference/generated/kubernetes-
api/v1.23/#container-v1-core).
container_predict_route: HTTP path on the container to send prediction
requests to. Vertex AI forwards requests sent using
projects.locations.endpoints.predict to this path on the container's IP
address and port. Vertex AI then returns the container's response in the
API response. For example, if you set this field to `/foo`, then when
Vertex AI receives a prediction request, it forwards the request body in
a POST request to the `/foo` path on the port of your container
specified by the first value of this `ModelContainerSpec`'s ports field.
If you don't specify this field, it defaults to the following value when
you deploy this Model to an Endpoint:
/v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict The
placeholders in this value are replaced as follows: * ENDPOINT: The last
segment (following `endpoints/`)of the Endpoint.name][] field of the
Endpoint where this Model has been deployed. (Vertex AI makes this value
available to your container code as the [`AIP_ENDPOINT_ID` environment
variable](https://cloud.google.com/vertex-ai/docs/predictions/custom-
container-requirements#aip-variables).) * DEPLOYED_MODEL:
DeployedModel.id of the `DeployedModel`. (Vertex AI makes this value
available to your container code as the [`AIP_DEPLOYED_MODEL_ID`
environment variable](https://cloud.google.com/vertex-
ai/docs/predictions/custom-container-requirements#aip-variables).)
container_health_route: HTTP path on the container to send health checks
to. Vertex AI intermittently sends GET requests to this path on the
container's IP address and port to check that the container is healthy.
Read more about [health checks](https://cloud.google.com/vertex-
ai/docs/predictions/custom-container-requirements#health). For example,
if you set this field to `/bar`, then Vertex AI intermittently sends a
GET request to the `/bar` path on the port of your container specified
by the first value of this `ModelContainerSpec`'s ports field. If you
don't specify this field, it defaults to the following value when you
deploy this Model to an Endpoint: /v1/endpoints/ENDPOINT/deployedModels/
DEPLOYED_MODEL:predict The placeholders in this value are replaced as
follows * ENDPOINT: The last segment (following `endpoints/`)of the
Endpoint.name][] field of the Endpoint where this Model has been
deployed. (Vertex AI makes this value available to your container code
as the [`AIP_ENDPOINT_ID` environment
variable](https://cloud.google.com/vertex-ai/docs/predictions/custom-
container-requirements#aip-variables).) * DEPLOYED_MODEL:
DeployedModel.id of the `DeployedModel`. (Vertex AI makes this value
available to your container code as the [`AIP_DEPLOYED_MODEL_ID`
environment variable](https://cloud.google.com/vertex-
ai/docs/predictions/custom-container-requirements#aip-variables).)
container_deployment_timeout_seconds (int): Deployment timeout in seconds.
container_shared_memory_size_mb (int): The amount of the VM memory to
reserve as the shared memory for the model in megabytes.
container_startup_probe_exec (Sequence[str]): Exec specifies the action to
take. Used by startup probe. An example of this argument would be
["cat", "/tmp/healthy"]
container_startup_probe_period_seconds (int): How often (in seconds) to
perform the startup probe. Default to 10 seconds. Minimum value is 1.
container_startup_probe_timeout_seconds (int): Number of seconds after
which the startup probe times out. Defaults to 1 second. Minimum value
is 1.
container_health_probe_exec (Sequence[str]): Exec specifies the action to
take. Used by health probe. An example of this argument would be ["cat",
"/tmp/healthy"]
container_health_probe_period_seconds (int): How often (in seconds) to
perform the health probe. Default to 10 seconds. Minimum value is 1.
container_health_probe_timeout_seconds (int): Number of seconds after
which the health probe times out. Defaults to 1 second. Minimum value is
1.
explanation_spec: The default explanation specification for this Model.
The Model can be used for requesting explanation after being deployed if
it is populated. The Model can be used for batch explanation if it is
populated. All fields of the explanation_spec can be overridden by
explanation_spec of DeployModelRequest.deployed_model, or
explanation_spec of BatchPredictionJob. If the default explanation
specification is not set for this Model, this Model can still be used
for requesting explanation by setting explanation_spec of
DeployModelRequest.deployed_model and for batch explanation by setting
explanation_spec of BatchPredictionJob.
parent_model: The resource name of the model into which to upload the
version. Only specify this field when uploading a new version.
model_id: The ID to use for the uploaded Model, which will become the
final component of the model resource name. This value may be up to 63
characters, and valid characters are `[a-z0-9_-]`. The first character
cannot be a number or hyphen..
version_aliases: User provided version aliases so that a model version can
be referenced via alias (i.e. projects/{project}/locations/{location}/mo
dels/{model_id}@{version_alias} instead of auto-generated version id
(i.e.
projects/{project}/locations/{location}/models/{model_id}@{version_id}).
The format is a-z{0,126}[a-z0-9] to distinguish from version_id. A
default version alias will be created for the first version of the
model, and there must be exactly one default version alias for a model.
labels: The labels with user-defined metadata to organize your Models.
Label keys and values can be no longer than 64 characters (Unicode
codepoints), can only contain lowercase letters, numeric characters,
underscores and dashes. International characters are allowed. See
https://goo.gl/xmQnxf for more information and examples of labels.
Returns:
Response from calling upload model with given request arguments.
"""
container_spec = self.messages.GoogleCloudAiplatformV1ModelContainerSpec(
healthRoute=container_health_route,
imageUri=container_image_uri,
predictRoute=container_predict_route)
if container_command:
container_spec.command = container_command
if container_args:
container_spec.args = container_args
if container_env_vars:
container_spec.env = [
self.messages.GoogleCloudAiplatformV1EnvVar(
name=k, value=container_env_vars[k]) for k in container_env_vars
]
if container_ports:
container_spec.ports = [
self.messages.GoogleCloudAiplatformV1Port(containerPort=port)
for port in container_ports
]
if container_grpc_ports:
container_spec.grpcPorts = [
self.messages.GoogleCloudAiplatformV1Port(containerPort=port)
for port in container_grpc_ports
]
if container_deployment_timeout_seconds:
container_spec.deploymentTimeout = (
str(container_deployment_timeout_seconds) + 's'
)
if container_shared_memory_size_mb:
container_spec.sharedMemorySizeMb = container_shared_memory_size_mb
if (
container_startup_probe_exec
or container_startup_probe_period_seconds
or container_startup_probe_timeout_seconds
):
startup_probe_exec = None
if container_startup_probe_exec:
startup_probe_exec = (
self.messages.GoogleCloudAiplatformV1ProbeExecAction(
command=container_startup_probe_exec
)
)
container_spec.startupProbe = (
self.messages.GoogleCloudAiplatformV1Probe(
exec_=startup_probe_exec,
periodSeconds=container_startup_probe_period_seconds,
timeoutSeconds=container_startup_probe_timeout_seconds,
)
)
if (
container_health_probe_exec
or container_health_probe_period_seconds
or container_health_probe_timeout_seconds
):
health_probe_exec = None
if container_health_probe_exec:
health_probe_exec = (
self.messages.GoogleCloudAiplatformV1ProbeExecAction(
command=container_health_probe_exec
)
)
container_spec.healthProbe = (
self.messages.GoogleCloudAiplatformV1Probe(
exec_=health_probe_exec,
periodSeconds=container_health_probe_period_seconds,
timeoutSeconds=container_health_probe_timeout_seconds,
)
)
model = self.messages.GoogleCloudAiplatformV1Model(
artifactUri=artifact_uri,
containerSpec=container_spec,
description=description,
versionDescription=version_description,
displayName=display_name,
explanationSpec=explanation_spec)
if version_aliases:
model.versionAliases = version_aliases
if labels:
additional_properties = []
for key, value in sorted(labels.items()):
additional_properties.append(model.LabelsValue().AdditionalProperty(
key=key, value=value))
model.labels = model.LabelsValue(
additionalProperties=additional_properties)
return self._service.Upload(
self.messages.AiplatformProjectsLocationsModelsUploadRequest(
parent=region_ref.RelativeName(),
googleCloudAiplatformV1UploadModelRequest=self.messages
.GoogleCloudAiplatformV1UploadModelRequest(
model=model,
parentModel=parent_model,
modelId=model_id)))
def Get(self, model_ref):
"""Gets (describe) the given model.
Args:
model_ref: The resource reference for a given model. None if model
resource reference is not provided.
Returns:
Response from calling get model with request containing given model.
"""
request = self.messages.AiplatformProjectsLocationsModelsGetRequest(
name=model_ref.RelativeName())
return self._service.Get(request)
def Delete(self, model_ref):
"""Deletes the given model.
Args:
model_ref: The resource reference for a given model. None if model
resource reference is not provided.
Returns:
Response from calling delete model with request containing given model.
"""
request = self.messages.AiplatformProjectsLocationsModelsDeleteRequest(
name=model_ref.RelativeName())
return self._service.Delete(request)
def DeleteVersion(self, model_version_ref):
"""Deletes the given model version.
Args:
model_version_ref: The resource reference for a given model version.
Returns:
Response from calling delete version with request containing given model
version.
"""
request = (
self.messages.AiplatformProjectsLocationsModelsDeleteVersionRequest(
name=model_version_ref.RelativeName()
)
)
return self._service.DeleteVersion(request)
def List(self, limit=None, region_ref=None):
"""List all models in the given region.
Args:
limit: int, The maximum number of records to yield. None if all available
records should be yielded.
region_ref: The resource reference for a given region. None if the region
reference is not provided.
Returns:
Response from calling list models with request containing given models
and limit.
"""
return list_pager.YieldFromList(
self._service,
self.messages.AiplatformProjectsLocationsModelsListRequest(
parent=region_ref.RelativeName()),
field='models',
batch_size_attribute='pageSize',
limit=limit)
def ListVersion(self, model_ref=None, limit=None):
"""List all model versions of the given model.
Args:
model_ref: The resource reference for a given model. None if model
resource reference is not provided.
limit: int, The maximum number of records to yield. None if all available
records should be yielded.
Returns:
Response from calling list model versions with request containing given
model and limit.
"""
return list_pager.YieldFromList(
self._service,
self.messages.AiplatformProjectsLocationsModelsListVersionsRequest(
name=model_ref.RelativeName()),
method='ListVersions',
field='models',
batch_size_attribute='pageSize',
limit=limit)
def CopyV1Beta1(self,
destination_region_ref=None,
source_model=None,
kms_key_name=None,
destination_model_id=None,
destination_parent_model=None):
"""Copies the given source model into specified location.
The source model is copied into specified location (including cross-region)
either as a new model or a new model version under given parent model.
Args:
destination_region_ref: the resource reference to the location into which
to copy the Model.
source_model: The resource name of the Model to copy.
kms_key_name: The KMS key name for specifying encryption spec.
destination_model_id: The destination model resource name to copy the
model into.
destination_parent_model: The destination parent model to copy the model
as a model version into.
Returns:
Response from calling copy model.
"""
encryption_spec = None
if kms_key_name:
encryption_spec = (
self.messages.GoogleCloudAiplatformV1beta1EncryptionSpec(
kmsKeyName=kms_key_name
)
)
request = self.messages.AiplatformProjectsLocationsModelsCopyRequest(
parent=destination_region_ref.RelativeName(),
googleCloudAiplatformV1beta1CopyModelRequest=self.messages
.GoogleCloudAiplatformV1beta1CopyModelRequest(
sourceModel=source_model,
encryptionSpec=encryption_spec,
parentModel=destination_parent_model,
modelId=destination_model_id))
return self._service.Copy(request)
def CopyV1(self,
destination_region_ref=None,
source_model=None,
kms_key_name=None,
destination_model_id=None,
destination_parent_model=None):
"""Copies the given source model into specified location.
The source model is copied into specified location (including cross-region)
either as a new model or a new model version under given parent model.
Args:
destination_region_ref: the resource reference to the location into which
to copy the Model.
source_model: The resource name of the Model to copy.
kms_key_name: The name of the KMS key to use for model encryption.
destination_model_id: Optional. Thew custom ID to be used as the resource
name of the new model. This value may be up to 63 characters, and valid
characters are `[a-z0-9_-]`. The first character cannot be a number or
hyphen.
destination_parent_model: The destination parent model to copy the model
as a model version into.
Returns:
Response from calling copy model.
"""
encryption_spec = None
if kms_key_name:
encryption_spec = (
self.messages.GoogleCloudAiplatformV1EncryptionSpec(
kmsKeyName=kms_key_name
)
)
request = self.messages.AiplatformProjectsLocationsModelsCopyRequest(
parent=destination_region_ref.RelativeName(),
googleCloudAiplatformV1CopyModelRequest=self.messages
.GoogleCloudAiplatformV1CopyModelRequest(
sourceModel=source_model,
encryptionSpec=encryption_spec,
parentModel=destination_parent_model,
modelId=destination_model_id))
return self._service.Copy(request)

View File

@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with long-running operations (simple uri)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.command_lib.ai import constants
def GetClientInstance(api_version=None, no_http=False):
  """Returns an apitools client instance for the AI Platform API.

  Args:
    api_version: str, the AI Platform API version to use, or None.
    no_http: bool, True to skip constructing an HTTP object for the client.
  """
  return apis.GetClientInstance(
      constants.AI_PLATFORM_API_NAME, api_version, no_http=no_http)
class AiPlatformOperationPoller(waiter.CloudOperationPoller):
  """Poller for the AI Platform operations API.

  This wrapper exists because the core operations library doesn't directly
  support simple_uri, so polling is delegated to the high-level
  OperationsClient.
  """

  def __init__(self, client):
    self.client = client
    # Both the operation service and the result service are the same
    # operations surface here; results are not decoded into another resource.
    operations_service = self.client.client.projects_locations_operations
    super(AiPlatformOperationPoller, self).__init__(operations_service,
                                                    operations_service)

  def Poll(self, operation_ref):
    """Fetches the latest state of the operation via the high-level client."""
    return self.client.Get(operation_ref)

  def GetResult(self, operation):
    """Returns the completed operation itself, without unwrapping a result."""
    return operation
class OperationsClient(object):
  """High-level client for the AI Platform operations surface."""

  def __init__(self,
               client=None,
               messages=None,
               version=constants.BETA_VERSION):
    # Fall back to a freshly-constructed client/messages pair for the
    # requested API version when none are injected.
    self.client = client or GetClientInstance(
        constants.AI_PLATFORM_API_VERSION[version])
    self.messages = messages or self.client.MESSAGES_MODULE

  def Get(self, operation_ref):
    """Fetches the operation named by the given resource reference."""
    request = self.messages.AiplatformProjectsLocationsOperationsGetRequest(
        name=operation_ref.RelativeName())
    return self.client.projects_locations_operations.Get(request)

  def WaitForOperation(
      self, operation, operation_ref, message=None, max_wait_ms=1800000
  ):
    """Waits until the operation is complete or times out.

    Args:
      operation: The operation resource to wait on.
      operation_ref: The operation reference to the operation resource. It's
        the result of calling resources.REGISTRY.Parse.
      message: str, the message to print while waiting.
      max_wait_ms: int, number of ms to wait before raising WaitException.

    Returns:
      The operation resource when it has completed.

    Raises:
      OperationTimeoutError: when the operation polling times out.
      OperationError: when the operation completed with an error.
    """
    poller = AiPlatformOperationPoller(self)
    # Short-circuit if the operation already finished; no progress UI needed.
    if poller.IsDone(operation):
      return operation
    if message is None:
      message = 'Waiting for operation [{}]'.format(operation_ref.Name())
    return waiter.WaitFor(
        poller, operation_ref, message, max_wait_ms=max_wait_ms)

View File

@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for querying Vertex AI Persistent Resources."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.export import util as export_util
from googlecloudsdk.core.console import console_io
class PersistentResourcesClient(object):
  """Client used for interacting with the PersistentResource endpoint."""

  def __init__(self, version=constants.GA_VERSION):
    api_client = apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version])
    self._messages = api_client.MESSAGES_MODULE
    self._version = version
    self._service = api_client.projects_locations_persistentResources
    self._message_prefix = constants.AI_PLATFORM_MESSAGE_PREFIX[version]

  def GetMessage(self, message_name):
    """Returns the versioned API message class by name, or None if absent."""
    return getattr(self._messages, self._message_prefix + message_name, None)

  def PersistentResourceMessage(self):
    """Returns the PersistentResource message class for this API version."""
    return self.GetMessage('PersistentResource')

  def Create(self,
             parent,
             resource_pools,
             persistent_resource_id,
             display_name=None,
             kms_key_name=None,
             labels=None,
             network=None,
             enable_custom_service_account=False,
             service_account=None):
    """Constructs and sends a request to create a persistent resource.

    Args:
      parent: str, the project resource path of the persistent resource to
        create.
      resource_pools: The resource pool messages for the creation request.
      persistent_resource_id: The PersistentResource id for the creation
        request.
      display_name: str, the display name of the persistent resource.
      kms_key_name: A customer-managed encryption key to use for the
        persistent resource.
      labels: LabelValues, map-like user-defined metadata to organize the
        resource.
      network: Network to peer with the PersistentResource.
      enable_custom_service_account: Whether to enable this Persistent
        Resource to use a custom service account.
      service_account: A service account (email address string) to use when
        creating the Persistent Resource.

    Returns:
      A PersistentResource message instance created.
    """
    persistent_resource = self.PersistentResourceMessage()(
        displayName=display_name, resourcePools=resource_pools)
    if kms_key_name is not None:
      persistent_resource.encryptionSpec = self.GetMessage('EncryptionSpec')(
          kmsKeyName=kms_key_name)
    if labels:
      persistent_resource.labels = labels
    if network:
      persistent_resource.network = network
    if enable_custom_service_account:
      service_account_spec = self.GetMessage('ServiceAccountSpec')(
          enableCustomServiceAccount=True, serviceAccount=service_account)
      persistent_resource.resourceRuntimeSpec = self.GetMessage(
          'ResourceRuntimeSpec')(serviceAccountSpec=service_account_spec)
    # The request field name carries the API version, so pick it per version.
    if self._version == constants.GA_VERSION:
      request = (
          self._messages
          .AiplatformProjectsLocationsPersistentResourcesCreateRequest(
              parent=parent,
              googleCloudAiplatformV1PersistentResource=persistent_resource,
              persistentResourceId=persistent_resource_id))
    else:
      request = (
          self._messages
          .AiplatformProjectsLocationsPersistentResourcesCreateRequest(
              parent=parent,
              googleCloudAiplatformV1beta1PersistentResource=(
                  persistent_resource),
              persistentResourceId=persistent_resource_id))
    return self._service.Create(request)

  def List(self, limit=None, region=None):
    """Lists persistent resources via the Persistent Resources endpoint.

    Args:
      limit: How many items to return in the list.
      region: Which region to list resources from.

    Returns:
      A generator over persistent resource messages, paging as needed.
    """
    request = (
        self._messages
        .AiplatformProjectsLocationsPersistentResourcesListRequest(
            parent=region))
    return list_pager.YieldFromList(
        self._service,
        request,
        field='persistentResources',
        batch_size_attribute='pageSize',
        limit=limit)

  def Get(self, name):
    """Fetches the persistent resource with the given resource name."""
    return self._service.Get(
        self._messages
        .AiplatformProjectsLocationsPersistentResourcesGetRequest(name=name))

  def Delete(self, name):
    """Deletes the persistent resource with the given resource name."""
    return self._service.Delete(
        self._messages
        .AiplatformProjectsLocationsPersistentResourcesDeleteRequest(
            name=name))

  def Reboot(self, name):
    """Reboots the persistent resource with the given resource name."""
    return self._service.Reboot(
        self._messages
        .AiplatformProjectsLocationsPersistentResourcesRebootRequest(
            name=name))

  def ImportResourceMessage(self, yaml_file, message_name):
    """Imports a messages class instance typed by name from a YAML file."""
    data = console_io.ReadFromFileOrStdin(yaml_file, binary=False)
    return export_util.Import(
        message_type=self.GetMessage(message_name), stream=data)

View File

@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is used to get the client instance and messages module for GKE recommender."""
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.calliope import base
# Maps gcloud release tracks to GKE recommender API versions.
VERSION_MAP = {
    base.ReleaseTrack.ALPHA: 'v1alpha1',
    base.ReleaseTrack.GA: 'v1',
}
# Format string used to render apitools HTTP errors for users.
HTTP_ERROR_FORMAT = (
    'ResponseError: code={status_code}, message={status_message}'
)
# The messages module can also be accessed from client.MESSAGES_MODULE
def GetMessagesModule(release_track=base.ReleaseTrack.GA):
  """Returns the gkerecommender messages module for the release track."""
  # The messages module can also be accessed from client.MESSAGES_MODULE.
  return apis.GetMessagesModule('gkerecommender',
                                VERSION_MAP.get(release_track))
def GetClientInstance(release_track=base.ReleaseTrack.GA):
  """Returns the gkerecommender client instance for the release track."""
  return apis.GetClientInstance('gkerecommender',
                                VERSION_MAP.get(release_track))

View File

@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for querying serverless ray jobs in AI Platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.export import util as export_util
from googlecloudsdk.core.console import console_io
class ServerlessRayJobsClient(object):
  """Client used for interacting with Serverless Ray Jobs endpoint."""

  def __init__(self, version=constants.GA_VERSION):
    """Instantiates the client for a given AI Platform API version.

    Args:
      version: str, key into the AI Platform version/message-prefix maps
        (e.g. GA or beta) selecting which API surface to talk to.
    """
    client = apis.GetClientInstance(constants.AI_PLATFORM_API_NAME,
                                    constants.AI_PLATFORM_API_VERSION[version])
    self._messages = client.MESSAGES_MODULE
    self._version = version
    self._service = client.projects_locations_serverlessRayJobs
    # Prefix used by GetMessage to resolve versioned message class names.
    self._message_prefix = constants.AI_PLATFORM_MESSAGE_PREFIX[version]

  def GetMessage(self, message_name):
    """Returns the API message class by name, or None if it does not exist."""
    return getattr(
        self._messages,
        '{prefix}{name}'.format(prefix=self._message_prefix,
                                name=message_name), None)

  def ServerlessRayJobMessage(self):
    """Returns the Serverless Ray Jobs resource message class."""
    return self.GetMessage('ServerlessRayJob')

  def Create(self,
             parent,
             job_spec,
             display_name=None,
             labels=None):
    """Constructs a request and sends it to the endpoint to create a serverless ray job instance.

    Args:
      parent: str, The project resource path of the serverless ray job to
        create.
      job_spec: The ServerlessRayJobSpec message instance for the job creation
        request.
      display_name: str, The display name of the serverless ray job to create.
      labels: LabelValues, map-like user-defined metadata to organize the
        serverless ray job.

    Returns:
      A ServerlessRayJob message instance created.
    """
    serverless_ray_job = self.ServerlessRayJobMessage()(
        displayName=display_name, jobSpec=job_spec
    )
    if labels:
      serverless_ray_job.labels = labels
    # TODO(b/390679825): Add V1 version support when Serverless Ray Jobs API is
    # GA ready.
    return self._service.Create(
        self._messages.AiplatformProjectsLocationsServerlessRayJobsCreateRequest(
            parent=parent,
            googleCloudAiplatformV1beta1ServerlessRayJob=serverless_ray_job,
        )
    )

  def List(self, limit=None, region=None):
    """Returns a generator yielding jobs under the given parent region.

    Args:
      limit: int or None, maximum number of jobs to yield (None for all).
      region: str, the parent location resource path to list under.
    """
    return list_pager.YieldFromList(
        self._service,
        self._messages.AiplatformProjectsLocationsServerlessRayJobsListRequest(
            parent=region
        ),
        field='serverlessRayJobs',
        batch_size_attribute='pageSize',
        limit=limit,
    )

  def Get(self, name):
    """Fetches a single serverless ray job by full resource name."""
    request = (
        self._messages.AiplatformProjectsLocationsServerlessRayJobsGetRequest(
            name=name
        )
    )
    return self._service.Get(request)

  def Cancel(self, name):
    """Cancels the serverless ray job with the given full resource name."""
    request = self._messages.AiplatformProjectsLocationsServerlessRayJobsCancelRequest(
        name=name
    )
    return self._service.Cancel(request)

  def ImportResourceMessage(self, yaml_file, message_name):
    """Import a messages class instance typed by name from a YAML file."""
    data = console_io.ReadFromFileOrStdin(yaml_file, binary=False)
    message_type = self.GetMessage(message_name)
    return export_util.Import(message_type=message_type, stream=data)

View File

@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform Tensorboard experiments API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import common_args
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.util.args import labels_util
class TensorboardExperimentsClient(object):
  """High-level client for the AI Platform Tensorboard experiment surface."""

  def __init__(self,
               client=None,
               messages=None,
               version=constants.BETA_VERSION):
    """Instantiates the client, defaulting to the beta API surface.

    Args:
      client: Optional pre-built API client; constructed from the version map
        when omitted.
      messages: Optional messages module; defaults to the client's module.
      version: str, key into the AI Platform API version map.
    """
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version])
    self.messages = messages or self.client.MESSAGES_MODULE
    self._service = self.client.projects_locations_tensorboards_experiments
    self._version = version

  def Create(self, tensorboard_ref, args):
    """Creates a Tensorboard experiment; currently always the beta surface."""
    return self.CreateBeta(tensorboard_ref, args)

  def CreateBeta(self, tensorboard_ref, args):
    """Create a new Tensorboard experiment."""
    labels = labels_util.ParseCreateArgs(
        args, self.messages.GoogleCloudAiplatformV1beta1TensorboardExperiment
        .LabelsValue)
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsCreateRequest(
        parent=tensorboard_ref.RelativeName(),
        googleCloudAiplatformV1beta1TensorboardExperiment=self.messages
        .GoogleCloudAiplatformV1beta1TensorboardExperiment(
            displayName=args.display_name,
            description=args.description,
            labels=labels),
        tensorboardExperimentId=args.tensorboard_experiment_id)
    return self._service.Create(request)

  def List(self, tensorboard_ref, limit=1000, page_size=50, sort_by=None):
    """Returns a generator yielding experiments under the given Tensorboard."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsListRequest(
        parent=tensorboard_ref.RelativeName(),
        orderBy=common_args.ParseSortByArg(sort_by))
    return list_pager.YieldFromList(
        self._service,
        request,
        field='tensorboardExperiments',
        batch_size_attribute='pageSize',
        batch_size=page_size,
        limit=limit)

  def Get(self, tensorboard_exp_ref):
    """Fetches a single Tensorboard experiment by resource reference."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsGetRequest(
        name=tensorboard_exp_ref.RelativeName())
    return self._service.Get(request)

  def Delete(self, tensorboard_exp_ref):
    """Deletes the given Tensorboard experiment."""
    request = (
        self.messages
        .AiplatformProjectsLocationsTensorboardsExperimentsDeleteRequest(
            name=tensorboard_exp_ref.RelativeName()))
    return self._service.Delete(request)

  def Patch(self, tensorboard_exp_ref, args):
    """Updates a Tensorboard experiment; currently always the beta surface."""
    return self.PatchBeta(tensorboard_exp_ref, args)

  def PatchBeta(self, tensorboard_exp_ref, args):
    """Update a Tensorboard experiment.

    Builds a field mask from only the flags the user actually set.

    Raises:
      errors.NoFieldsSpecifiedError: if no updatable flag was provided.
    """
    tensorboard_exp = (
        self.messages.GoogleCloudAiplatformV1beta1TensorboardExperiment())
    update_mask = []

    # Deferred so the existing labels are only fetched when an update needs
    # them (e.g. for --update-labels / --remove-labels merging).
    def GetLabels():
      return self.Get(tensorboard_exp_ref).labels

    labels_update = labels_util.ProcessUpdateArgsLazy(
        args, self.messages.GoogleCloudAiplatformV1beta1TensorboardExperiment
        .LabelsValue, GetLabels)
    if labels_update.needs_update:
      tensorboard_exp.labels = labels_update.labels
      update_mask.append('labels')
    if args.display_name is not None:
      tensorboard_exp.displayName = args.display_name
      update_mask.append('display_name')
    if args.description is not None:
      tensorboard_exp.description = args.description
      update_mask.append('description')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsPatchRequest(
        name=tensorboard_exp_ref.RelativeName(),
        googleCloudAiplatformV1beta1TensorboardExperiment=tensorboard_exp,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)

View File

@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform Tensorboard runs API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import common_args
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.util.args import labels_util
class TensorboardRunsClient(object):
  """High-level client for the AI Platform Tensorboard run surface."""

  def __init__(self,
               client=None,
               messages=None,
               version=constants.BETA_VERSION):
    """Instantiates the client, defaulting to the beta API surface.

    Args:
      client: Optional pre-built API client; constructed from the version map
        when omitted.
      messages: Optional messages module; defaults to the client's module.
      version: str, key into the AI Platform API version map.
    """
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version])
    self.messages = messages or self.client.MESSAGES_MODULE
    self._service = self.client.projects_locations_tensorboards_experiments_runs
    self._version = version

  def Create(self, tensorboard_exp_ref, args):
    """Creates a Tensorboard run; currently always the beta surface."""
    return self.CreateBeta(tensorboard_exp_ref, args)

  def CreateBeta(self, tensorboard_exp_ref, args):
    """Create a new Tensorboard run."""
    labels = labels_util.ParseCreateArgs(
        args,
        self.messages.GoogleCloudAiplatformV1beta1TensorboardRun.LabelsValue)
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsCreateRequest(
        parent=tensorboard_exp_ref.RelativeName(),
        googleCloudAiplatformV1beta1TensorboardRun=self.messages
        .GoogleCloudAiplatformV1beta1TensorboardRun(
            displayName=args.display_name,
            description=args.description,
            labels=labels),
        tensorboardRunId=args.tensorboard_run_id)
    return self._service.Create(request)

  def List(self, tensorboard_exp_ref, limit=1000, page_size=50, sort_by=None):
    """Returns a generator yielding runs under the given experiment."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsListRequest(
        parent=tensorboard_exp_ref.RelativeName(),
        orderBy=common_args.ParseSortByArg(sort_by))
    return list_pager.YieldFromList(
        self._service,
        request,
        field='tensorboardRuns',
        batch_size_attribute='pageSize',
        batch_size=page_size,
        limit=limit)

  def Get(self, tensorboard_run_ref):
    """Fetches a single Tensorboard run by resource reference."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsGetRequest(
        name=tensorboard_run_ref.RelativeName())
    return self._service.Get(request)

  def Delete(self, tensorboard_run_ref):
    """Deletes the given Tensorboard run."""
    request = (
        self.messages
        .AiplatformProjectsLocationsTensorboardsExperimentsRunsDeleteRequest(
            name=tensorboard_run_ref.RelativeName()))
    return self._service.Delete(request)

  def Patch(self, tensorboard_run_ref, args):
    """Updates a Tensorboard run; currently always the beta surface."""
    return self.PatchBeta(tensorboard_run_ref, args)

  def PatchBeta(self, tensorboard_run_ref, args):
    """Update a Tensorboard run.

    Builds a field mask from only the flags the user actually set.

    Raises:
      errors.NoFieldsSpecifiedError: if no updatable flag was provided.
    """
    tensorboard_run = self.messages.GoogleCloudAiplatformV1beta1TensorboardRun()
    update_mask = []

    # Deferred so the existing labels are only fetched when an update needs
    # them.
    def GetLabels():
      return self.Get(tensorboard_run_ref).labels

    labels_update = labels_util.ProcessUpdateArgsLazy(
        args,
        self.messages.GoogleCloudAiplatformV1beta1TensorboardRun.LabelsValue,
        GetLabels)
    if labels_update.needs_update:
      tensorboard_run.labels = labels_update.labels
      update_mask.append('labels')
    if args.display_name is not None:
      tensorboard_run.displayName = args.display_name
      update_mask.append('display_name')
    if args.description is not None:
      tensorboard_run.description = args.description
      update_mask.append('description')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsPatchRequest(
        name=tensorboard_run_ref.RelativeName(),
        googleCloudAiplatformV1beta1TensorboardRun=tensorboard_run,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)

View File

@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform Tensorboard time series API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import common_args
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
def GetMessagesModule(version=constants.BETA_VERSION):
  """Returns the AI Platform messages module for the given API version key."""
  api_version = constants.AI_PLATFORM_API_VERSION[version]
  return apis.GetMessagesModule(constants.AI_PLATFORM_API_NAME, api_version)
class TensorboardTimeSeriesClient(object):
  """High-level client for the AI Platform Tensorboard time series surface."""

  def __init__(self,
               client=None,
               messages=None,
               version=constants.BETA_VERSION):
    """Instantiates the client, defaulting to the beta API surface.

    Args:
      client: Optional pre-built API client; constructed from the version map
        when omitted.
      messages: Optional messages module; defaults to the client's module.
      version: str, key into the AI Platform API version map.
    """
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version])
    self.messages = messages or self.client.MESSAGES_MODULE
    self._service = self.client.projects_locations_tensorboards_experiments_runs_timeSeries
    self._version = version

  def Create(self, tensorboard_run_ref, args):
    """Creates a Tensorboard time series; currently always the beta surface."""
    return self.CreateBeta(tensorboard_run_ref, args)

  def CreateBeta(self, tensorboard_run_ref, args):
    """Create a new Tensorboard time series."""
    # Map the user-facing --type flag to the enum; anything other than
    # 'scalar' or 'blob-sequence' falls through to TENSOR.
    if args.type == 'scalar':
      value_type = (
          self.messages.GoogleCloudAiplatformV1beta1TensorboardTimeSeries
          .ValueTypeValueValuesEnum.SCALAR)
    elif args.type == 'blob-sequence':
      value_type = (
          self.messages.GoogleCloudAiplatformV1beta1TensorboardTimeSeries
          .ValueTypeValueValuesEnum.BLOB_SEQUENCE)
    else:
      value_type = (
          self.messages.GoogleCloudAiplatformV1beta1TensorboardTimeSeries
          .ValueTypeValueValuesEnum.TENSOR)
    # pluginData is a bytes field, so an unset flag becomes empty bytes.
    if args.plugin_data is None:
      plugin_data = ''
    else:
      plugin_data = args.plugin_data
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsTimeSeriesCreateRequest(
        parent=tensorboard_run_ref.RelativeName(),
        googleCloudAiplatformV1beta1TensorboardTimeSeries=self.messages
        .GoogleCloudAiplatformV1beta1TensorboardTimeSeries(
            displayName=args.display_name,
            description=args.description,
            valueType=value_type,
            pluginName=args.plugin_name,
            pluginData=bytes(plugin_data, encoding='utf8')))
    return self._service.Create(request)

  def List(self, tensorboard_run_ref, limit=1000, page_size=50, sort_by=None):
    """Returns a generator yielding time series under the given run."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsTimeSeriesListRequest(
        parent=tensorboard_run_ref.RelativeName(),
        orderBy=common_args.ParseSortByArg(sort_by))
    return list_pager.YieldFromList(
        self._service,
        request,
        field='tensorboardTimeSeries',
        batch_size_attribute='pageSize',
        batch_size=page_size,
        limit=limit)

  def Get(self, tensorboard_time_series_ref):
    """Fetches a single Tensorboard time series by resource reference."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsTimeSeriesGetRequest(
        name=tensorboard_time_series_ref.RelativeName())
    return self._service.Get(request)

  def Delete(self, tensorboard_time_series_ref):
    """Deletes the given Tensorboard time series."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsTimeSeriesDeleteRequest(
        name=tensorboard_time_series_ref.RelativeName())
    return self._service.Delete(request)

  def Patch(self, tensorboard_time_series_ref, args):
    """Updates a Tensorboard time series; currently always the beta surface."""
    return self.PatchBeta(tensorboard_time_series_ref, args)

  def PatchBeta(self, tensorboard_time_series_ref, args):
    """Update a Tensorboard time series.

    Builds a field mask from only the flags the user actually set.

    Raises:
      errors.NoFieldsSpecifiedError: if no updatable flag was provided.
    """
    tensorboard_time_series = (
        self.messages.GoogleCloudAiplatformV1beta1TensorboardTimeSeries())
    update_mask = []
    if args.display_name is not None:
      tensorboard_time_series.displayName = args.display_name
      update_mask.append('display_name')
    if args.description is not None:
      tensorboard_time_series.description = args.description
      update_mask.append('description')
    if args.plugin_name is not None:
      tensorboard_time_series.pluginName = args.plugin_name
      update_mask.append('plugin_name')
    if args.plugin_data is not None:
      tensorboard_time_series.pluginData = bytes(
          args.plugin_data, encoding='utf8')
      update_mask.append('plugin_data')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsTimeSeriesPatchRequest(
        name=tensorboard_time_series_ref.RelativeName(),
        googleCloudAiplatformV1beta1TensorboardTimeSeries=tensorboard_time_series,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)

  def Read(self, tensorboard_time_series_ref, max_data_points, data_filter):
    """Reads data points from the time series, optionally filtered/capped."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsTimeSeriesReadRequest(
        tensorboardTimeSeries=tensorboard_time_series_ref.RelativeName(),
        maxDataPoints=max_data_points,
        filter=data_filter)
    return self._service.Read(request)

View File

@@ -0,0 +1,155 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform Tensorboards API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.ai import util as api_util
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import common_args
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.ai import validation as common_validation
from googlecloudsdk.command_lib.util.args import labels_util
class TensorboardsClient(object):
  """High-level client for the AI Platform Tensorboard surface."""

  def __init__(self,
               client=None,
               messages=None,
               version=constants.BETA_VERSION):
    """Instantiates the client, defaulting to the beta API surface.

    Args:
      client: Optional pre-built API client; constructed from the version map
        when omitted.
      messages: Optional messages module; defaults to the client's module.
      version: str, key into the AI Platform API version map; also selects
        between GA and beta message types in Create/Patch.
    """
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version])
    self.messages = messages or self.client.MESSAGES_MODULE
    self._service = self.client.projects_locations_tensorboards
    self._version = version

  def Create(self, location_ref, args):
    """Dispatches creation to the GA or beta message types by version."""
    if self._version == constants.GA_VERSION:
      return self.CreateGa(location_ref, args)
    else:
      return self.CreateBeta(location_ref, args)

  def CreateGa(self, location_ref, args):
    """Create a new Tensorboard."""
    kms_key_name = common_validation.GetAndValidateKmsKey(args)
    labels = labels_util.ParseCreateArgs(
        args, self.messages.GoogleCloudAiplatformV1Tensorboard.LabelsValue)
    tensorboard = self.messages.GoogleCloudAiplatformV1Tensorboard(
        displayName=args.display_name,
        description=args.description,
        labels=labels)
    # Only attach an encryption spec when a CMEK key was provided.
    if kms_key_name is not None:
      tensorboard.encryptionSpec = api_util.GetMessage(
          'EncryptionSpec', self._version)(
              kmsKeyName=kms_key_name)
    request = self.messages.AiplatformProjectsLocationsTensorboardsCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1Tensorboard=tensorboard)
    return self._service.Create(request)

  def CreateBeta(self, location_ref, args):
    """Create a new Tensorboard."""
    kms_key_name = common_validation.GetAndValidateKmsKey(args)
    labels = labels_util.ParseCreateArgs(
        args, self.messages.GoogleCloudAiplatformV1beta1Tensorboard.LabelsValue)
    tensorboard = self.messages.GoogleCloudAiplatformV1beta1Tensorboard(
        displayName=args.display_name,
        description=args.description,
        labels=labels)
    # Only attach an encryption spec when a CMEK key was provided.
    if kms_key_name is not None:
      tensorboard.encryptionSpec = api_util.GetMessage(
          'EncryptionSpec', self._version)(
              kmsKeyName=kms_key_name)
    request = self.messages.AiplatformProjectsLocationsTensorboardsCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1beta1Tensorboard=tensorboard)
    return self._service.Create(request)

  def Get(self, tensorboard_ref):
    """Fetches a single Tensorboard by resource reference."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsGetRequest(
        name=tensorboard_ref.RelativeName())
    return self._service.Get(request)

  def List(self, limit=1000, page_size=50, region_ref=None, sort_by=None):
    """Returns a generator yielding Tensorboards under the given region.

    NOTE(review): region_ref defaults to None but is dereferenced
    unconditionally, so callers must always pass it — confirm before relying
    on the default.
    """
    request = self.messages.AiplatformProjectsLocationsTensorboardsListRequest(
        parent=region_ref.RelativeName(),
        orderBy=common_args.ParseSortByArg(sort_by))
    return list_pager.YieldFromList(
        self._service,
        request,
        field='tensorboards',
        batch_size_attribute='pageSize',
        batch_size=page_size,
        limit=limit)

  def Delete(self, tensorboard_ref):
    """Deletes the given Tensorboard."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsDeleteRequest(
        name=tensorboard_ref.RelativeName())
    return self._service.Delete(request)

  def Patch(self, tensorboard_ref, args):
    """Update a Tensorboard.

    Builds a field mask from only the flags the user actually set, using the
    GA or beta message types depending on the client version.

    Raises:
      errors.NoFieldsSpecifiedError: if no updatable flag was provided.
    """
    if self._version == constants.GA_VERSION:
      tensorboard = self.messages.GoogleCloudAiplatformV1Tensorboard()
      labels_value = self.messages.GoogleCloudAiplatformV1Tensorboard.LabelsValue
    else:
      tensorboard = self.messages.GoogleCloudAiplatformV1beta1Tensorboard()
      labels_value = self.messages.GoogleCloudAiplatformV1beta1Tensorboard.LabelsValue
    update_mask = []

    # Deferred so the existing labels are only fetched when an update needs
    # them.
    def GetLabels():
      return self.Get(tensorboard_ref).labels

    labels_update = labels_util.ProcessUpdateArgsLazy(args, labels_value,
                                                      GetLabels)
    if labels_update.needs_update:
      tensorboard.labels = labels_update.labels
      update_mask.append('labels')
    if args.display_name is not None:
      tensorboard.displayName = args.display_name
      update_mask.append('display_name')
    if args.description is not None:
      tensorboard.description = args.description
      update_mask.append('description')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    if self._version == constants.GA_VERSION:
      req = self.messages.AiplatformProjectsLocationsTensorboardsPatchRequest(
          name=tensorboard_ref.RelativeName(),
          googleCloudAiplatformV1Tensorboard=tensorboard,
          updateMask=','.join(update_mask))
    else:
      req = self.messages.AiplatformProjectsLocationsTensorboardsPatchRequest(
          name=tensorboard_ref.RelativeName(),
          googleCloudAiplatformV1beta1Tensorboard=tensorboard,
          updateMask=','.join(update_mask))
    return self._service.Patch(req)

View File

@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities for dealing with Vertex AI api messages."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
def GetMessagesModule(version=constants.GA_VERSION):
  """Returns message module of the corresponding API version."""
  api_version = constants.AI_PLATFORM_API_VERSION[version]
  return apis.GetMessagesModule(constants.AI_PLATFORM_API_NAME, api_version)
def GetMessage(message_name, version=constants.GA_VERSION):
  """Returns the Vertex AI api messages class by name, or None if absent."""
  prefixed_name = '{prefix}{name}'.format(
      prefix=constants.AI_PLATFORM_MESSAGE_PREFIX[version], name=message_name)
  return getattr(GetMessagesModule(version), prefixed_name, None)

View File

@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utility functions for getting the alloydb API client."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.calliope import base
from googlecloudsdk.core import log
from googlecloudsdk.core import resources
# API version constants
# Release track used when no explicit track is supplied.
DEFAULT_RELEASE_TRACK = base.ReleaseTrack.ALPHA
# Maps gcloud release tracks to alloydb API versions.
VERSION_MAP = {
    base.ReleaseTrack.ALPHA: 'v1alpha',
    base.ReleaseTrack.BETA: 'v1beta',
    base.ReleaseTrack.GA: 'v1',
}
# API version corresponding to the default release track.
API_VERSION_DEFAULT = VERSION_MAP[DEFAULT_RELEASE_TRACK]
class AlloyDBClient(object):
  """Wrapper for alloydb API client and associated resources."""

  def __init__(self, release_track):
    """Builds the versioned client, messages module, and resource parser.

    Args:
      release_track: base.ReleaseTrack, selects the alloydb API version.
    """
    version = VERSION_MAP[release_track]
    self.release_track = release_track
    self.alloydb_client = apis.GetClientInstance('alloydb', version)
    self.alloydb_messages = self.alloydb_client.MESSAGES_MODULE
    # Registry for parsing alloydb resource names at the same API version.
    self.resource_parser = resources.Registry()
    self.resource_parser.RegisterApiByName('alloydb', version)
def GetMessagesModule(release_track):
  """Returns the message module for release track."""
  return apis.GetMessagesModule('alloydb', VERSION_MAP[release_track])
def YieldFromListHandlingUnreachable(*args, **kwargs):
  """Yields from paged List calls, warning about unreachable locations.

  Accumulates the `unreachable` field reported on each response page and,
  once iteration completes, logs a single warning listing them.
  """
  unreachable_locations = set()

  def _TrackUnreachable(message, attribute):
    # Record per-page unreachable locations before extracting the items.
    unreachable_locations.update(message.unreachable)
    return getattr(message, attribute)

  pages = list_pager.YieldFromList(
      *args, get_field_func=_TrackUnreachable, **kwargs)
  for resource in pages:
    yield resource
  if unreachable_locations:
    log.warning(
        'The following locations were unreachable: %s',
        ', '.join(sorted(unreachable_locations)),
    )

View File

@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AlloyDB backup operations API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.alloydb import api_util
from googlecloudsdk.api_lib.util import waiter
def Await(op_ref, message, release_track, creates_resource=True):
  """Waits for the given google.longrunning.Operation to complete.

  Args:
    op_ref: The operation to poll.
    message: String to display for default progress_tracker.
    release_track: The API release track (e.g. ALPHA, BETA, etc.)
    creates_resource: Whether or not the operation creates a resource.

  Raises:
    apitools.base.py.HttpError: If the request returns an HTTP error.

  Returns:
    The Operation or the Resource the Operation is associated with.
  """
  alloydb_client = api_util.AlloyDBClient(release_track).alloydb_client
  operations_service = alloydb_client.projects_locations_operations
  if creates_resource:
    poller = waiter.CloudOperationPoller(
        alloydb_client.projects_locations_backups, operations_service)
  else:
    poller = waiter.CloudOperationPollerNoResources(operations_service)
  # Poll at a fixed 10s cadence rather than backing off exponentially.
  return waiter.WaitFor(
      poller, op_ref, message, exponential_sleep_multiplier=1.0,
      sleep_ms=10000)

View File

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AlloyDB cluster operations API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.alloydb import api_util
from googlecloudsdk.api_lib.util import waiter
def Await(op_ref, message, release_track, creates_resource=True):
  """Waits for the given google.longrunning.Operation to complete.

  Args:
    op_ref: The operation to poll.
    message: String to display for default progress_tracker.
    release_track: The API release track (e.g. ALPHA, BETA, etc.)
    creates_resource: Whether or not the operation creates a resource.

  Raises:
    apitools.base.py.HttpError: If the request returns an HTTP error.

  Returns:
    The Operation or the Resource the Operation is associated with.
  """
  client = api_util.AlloyDBClient(release_track)
  alloydb_client = client.alloydb_client
  if creates_resource:
    poller = waiter.CloudOperationPoller(
        alloydb_client.projects_locations_clusters,
        alloydb_client.projects_locations_operations)
  else:
    poller = waiter.CloudOperationPollerNoResources(
        alloydb_client.projects_locations_operations)
  # Use the same fixed 10s polling cadence as the backup and instance
  # operation helpers, so long-running cluster operations poll consistently
  # instead of backing off exponentially.
  return waiter.WaitFor(
      poller, op_ref, message, exponential_sleep_multiplier=1.0,
      sleep_ms=10000)

View File

@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AlloyDB instance operations API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.alloydb import api_util
from googlecloudsdk.api_lib.util import waiter
def Await(op_ref, message, release_track, creates_resource=True):
  """Waits for the given google.longrunning.Operation to complete.

  Args:
    op_ref: The operation to poll.
    message: String to display for default progress_tracker.
    release_track: The API release track (e.g. ALPHA, BETA, etc.)
    creates_resource: Whether or not the operation creates a resource

  Raises:
    apitools.base.py.HttpError: if the request returns an HTTP error

  Returns:
    The Operation or the Resource the Operation is associated with.
  """
  alloydb_client = api_util.AlloyDBClient(release_track).alloydb_client
  operations_service = alloydb_client.projects_locations_operations
  if creates_resource:
    poller = waiter.CloudOperationPoller(
        alloydb_client.projects_locations_clusters_instances,
        operations_service)
  else:
    poller = waiter.CloudOperationPollerNoResources(operations_service)
  # Poll at a fixed 10s cadence rather than backing off exponentially.
  return waiter.WaitFor(
      poller, op_ref, message, exponential_sleep_multiplier=1.0,
      sleep_ms=10000)

View File

@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Client for interaction with Api Config CRUD on API Gateway API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.api_gateway import base
from googlecloudsdk.calliope import exceptions as calliope_exceptions
from googlecloudsdk.command_lib.api_gateway import common_flags
class ApiConfigClient(base.BaseClient):
  """Client for Api Config objects on Cloud API Gateway API."""

  def __init__(self, client=None):
    base.BaseClient.__init__(
        self,
        client=client,
        message_base='ApigatewayProjectsLocationsApisConfigs',
        service_name='projects_locations_apis_configs')
    self.DefineDelete()
    self.DefineList('apiConfigs')
    self.DefineUpdate('apigatewayApiConfig')
    # Map user-facing view names onto the generated enum values.
    view_enum = (
        self.messages.ApigatewayProjectsLocationsApisConfigsGetRequest
        .ViewValueValuesEnum)
    self.supported_views = {
        'FULL': view_enum.FULL,
        'BASIC': view_enum.BASIC
    }

  def Create(self, api_config_ref, display_name=None, labels=None,
             backend_auth=None, managed_service_configs=None,
             grpc_service_defs=None, open_api_docs=None):
    """Creates an Api Config object.

    Args:
      api_config_ref: A parsed resource reference for the api
      display_name: Optional string display name
      labels: Optional cloud labels (as provided in the labels argument)
      backend_auth: Optional string to set the service account for backend auth
      managed_service_configs: Optional list of managed service configurations,
        in the form of the ApigatewayApiConfigFileMessage's generated from the
        discovery document
      grpc_service_defs: Optional list of GRPC service definitions, in the form
        of ApigatewayApiConfigGrpcServiceDefinition's generated from the
        discovery document
      open_api_docs: Optional list of Open API documents, in the form of
        ApigatewayApiConfigOpenApiDocument's generated from the discovery
        document

    Returns:
      Long running operation
    """
    label_values = common_flags.ProcessLabelsFlag(
        labels, self.messages.ApigatewayApiConfig.LabelsValue)
    config = self.messages.ApigatewayApiConfig(
        name=api_config_ref.RelativeName(),
        displayName=display_name,
        labels=label_values,
        gatewayServiceAccount=backend_auth,
        managedServiceConfigs=managed_service_configs,
        grpcServices=grpc_service_defs,
        openapiDocuments=open_api_docs)
    create_req = self.create_request(
        apiConfigId=api_config_ref.Name(),
        apigatewayApiConfig=config,
        parent=api_config_ref.Parent().RelativeName())
    return self.service.Create(create_req)

  def Get(self, api_config_ref, view=None):
    """Returns an API Config object.

    Args:
      api_config_ref: A parsed resource reference for the API.
      view: Optional string. If specified as FULL, the source config files
        will be returned.

    Returns:
      An API Config object.

    Raises:
      calliope_exceptions.InvalidArgumentException: If an invalid view (i.e.
        not FULL, BASIC, or none) was provided.
    """
    view_value = None
    if view is not None:
      view_key = view.upper()
      if view_key not in self.supported_views:
        raise calliope_exceptions.InvalidArgumentException(
            '--view', 'View must be one of: "FULL" or "BASIC".')
      view_value = self.supported_views[view_key]
    get_req = self.get_request(
        name=api_config_ref.RelativeName(), view=view_value)
    return self.service.Get(get_req)

View File

@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Client for interaction with Api CRUD on API Gateway API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import exceptions as apitools_exceptions
from googlecloudsdk.api_lib.api_gateway import base
from googlecloudsdk.command_lib.api_gateway import common_flags
class ApiClient(base.BaseClient):
  """Client for Api objects on Cloud API Gateway API."""

  def __init__(self, client=None):
    base.BaseClient.__init__(self,
                             client=client,
                             message_base='ApigatewayProjectsLocationsApis',
                             service_name='projects_locations_apis')
    self.DefineGet()
    self.DefineList('apis')
    self.DefineUpdate('apigatewayApi')
    self.DefineDelete()
    self.DefineIamPolicyFunctions()

  def DoesExist(self, api_ref):
    """Checks if an Api object exists.

    Args:
      api_ref: Resource, a resource reference for the api

    Returns:
      Boolean, indicating whether or not exists
    """
    try:
      self.Get(api_ref)
    except apitools_exceptions.HttpNotFoundError:
      return False
    else:
      return True

  def Create(self, api_ref, managed_service=None, labels=None,
             display_name=None):
    """Creates a new Api object.

    Args:
      api_ref: Resource, a resource reference for the api
      managed_service: Optional string, reference name for OP service
      labels: Optional cloud labels
      display_name: Optional display name

    Returns:
      Long running operation response object.
    """
    label_values = common_flags.ProcessLabelsFlag(
        labels, self.messages.ApigatewayApi.LabelsValue)
    new_api = self.messages.ApigatewayApi(
        name=api_ref.RelativeName(),
        displayName=display_name,
        labels=label_values,
        managedService=managed_service)
    create_req = self.create_request(
        apiId=api_ref.Name(),
        apigatewayApi=new_api,
        parent=api_ref.Parent().RelativeName())
    return self.service.Create(create_req)

View File

@@ -0,0 +1,250 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Client for interaction with Gateway CRUD on API Gateway API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import types
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.iam import iam_util
def GetClientInstance(version='v1', no_http=False):
  """Returns an apitools client instance for the API Gateway API."""
  client = apis.GetClientInstance('apigateway', version, no_http=no_http)
  return client
def GetMessagesModule(version='v1'):
  """Returns the generated messages module for the API Gateway API."""
  messages = apis.GetMessagesModule('apigateway', version)
  return messages
class BaseClient(object):
  """Base for building API Clients.

  Subclasses opt in to standard CRUD/IAM operations by calling the Define*
  methods from their __init__; each Define* binds a closure onto the
  instance with types.MethodType, so only the requested operations exist
  on a given subclass instance.
  """

  def __init__(self, client=None, message_base=None, service_name=None):
    """Initializes the shared client state.

    Args:
      client: Optional apitools client; a fresh one is created when omitted.
      message_base: str, prefix shared by this service's request messages
        (e.g. 'ApigatewayProjectsLocationsApis').
      service_name: str, attribute name of the service on the apitools client.
    """
    self.client = client or GetClientInstance()
    self.messages = self.client.MESSAGES_MODULE
    self.service = getattr(self.client, service_name, None)
    # Define standard request types if they exist for the base message;
    # missing ones resolve to None so subclasses only use what they define.
    self.get_request = getattr(self.messages, message_base + 'GetRequest', None)
    self.create_request = getattr(self.messages,
                                  message_base + 'CreateRequest',
                                  None)
    self.list_request = getattr(self.messages,
                                message_base + 'ListRequest',
                                None)
    self.patch_request = getattr(self.messages,
                                 message_base + 'PatchRequest',
                                 None)
    self.delete_request = getattr(self.messages,
                                  message_base + 'DeleteRequest',
                                  None)
    # Define IAM request types if they exist for the base message
    self.get_iam_policy_request = getattr(self.messages,
                                          message_base + 'GetIamPolicyRequest',
                                          None)
    self.set_iam_policy_request = getattr(self.messages,
                                          message_base + 'SetIamPolicyRequest',
                                          None)

  def DefineGet(self):
    """Defines basic get function on an assigned class."""
    def Get(self, object_ref):
      """Gets an object.

      Args:
        self: The self of the class this is set on.
        object_ref: Resource, resource reference for object to get.

      Returns:
        The object requested.
      """
      req = self.get_request(name=object_ref.RelativeName())
      return self.service.Get(req)
    # Bind the function to the method and set the attribute
    setattr(self, 'Get', types.MethodType(Get, self))

  def DefineDelete(self):
    """Defines basic delete function on an assigned class."""
    def Delete(self, object_ref):
      """Deletes a given object given an object name.

      Args:
        self: The self of the class this is set on.
        object_ref: Resource, resource reference for object to delete.

      Returns:
        Long running operation.
      """
      req = self.delete_request(name=object_ref.RelativeName())
      return self.service.Delete(req)
    # Bind the function to the method and set the attribute
    setattr(self, 'Delete', types.MethodType(Delete, self))

  def DefineList(self, field_name, is_operations=False):
    """Defines the List functionality on the calling class.

    Args:
      field_name: The name of the field on the list response to list
      is_operations: Operations have a slightly altered message structure, set
        to true in operations client
    """
    def List(self, parent_name, filters=None, limit=None, page_size=None,
             sort_by=None):
      """Lists the objects under a given parent.

      Args:
        self: the self object function will be bound to.
        parent_name: Resource name of the parent to list under.
        filters: Filters to be applied to results (optional).
        limit: Limit to the number of results per page (optional).
        page_size: the number of results per page (optional).
        sort_by: Instructions about how to sort the results (optional).

      Returns:
        List Pager.
      """
      # Operations list requests address the parent via `name` and do not
      # accept an orderBy field.
      if is_operations:
        req = self.list_request(filter=filters, name=parent_name)
      else:
        req = self.list_request(filter=filters, parent=parent_name,
                                orderBy=sort_by)
      return list_pager.YieldFromList(
          self.service,
          req,
          limit=limit,
          batch_size_attribute='pageSize',
          batch_size=page_size,
          field=field_name)
    # Bind the function to the method and set the attribute
    setattr(self, 'List', types.MethodType(List, self))

  def DefineUpdate(self, update_field_name):
    """Defines the Update functionality on the calling class.

    Args:
      update_field_name: the field on the patch_request to assign updated
        object to
    """
    def Update(self, updating_object, update_mask=None):
      """Updates an object.

      Args:
        self: The self of the class this is set on.
        updating_object: Object which is being updated.
        update_mask: A string saying which fields have been updated.

      Returns:
        Long running operation.
      """
      req = self.patch_request(name=updating_object.name,
                               updateMask=update_mask)
      # The request field holding the object differs per service, so it is
      # assigned dynamically by name.
      setattr(req, update_field_name, updating_object)
      return self.service.Patch(req)
    # Bind the function to the method and set the attribute
    setattr(self, 'Update', types.MethodType(Update, self))

  def DefineIamPolicyFunctions(self):
    """Defines all of the IAM functionality on the calling class."""
    def GetIamPolicy(self, object_ref):
      """Gets an IAM Policy on an object.

      Args:
        self: The self of the class this is set on.
        object_ref: Resource, reference for object IAM policy belongs to.

      Returns:
        The IAM policy.
      """
      req = self.get_iam_policy_request(resource=object_ref.RelativeName())
      return self.service.GetIamPolicy(req)

    def SetIamPolicy(self, object_ref, policy, update_mask=None):
      """Sets an IAM Policy on an object.

      Args:
        self: The self of the class this is set on.
        object_ref: Resource, reference for object IAM policy belongs to.
        policy: the policy to be set.
        update_mask: fields being update on the IAM policy.

      Returns:
        The IAM policy.
      """
      policy_request = self.messages.ApigatewaySetIamPolicyRequest(
          policy=policy,
          updateMask=update_mask)
      req = self.set_iam_policy_request(
          apigatewaySetIamPolicyRequest=policy_request,
          resource=object_ref.RelativeName())
      return self.service.SetIamPolicy(req)

    def AddIamPolicyBinding(self, object_ref, member, role):
      """Adds an IAM role to a member on an object.

      Args:
        self: The self of the class this is set on.
        object_ref: Resource, reference for object IAM policy belongs to.
        member: the member the binding is being added to.
        role: the role which to bind to the member.

      Returns:
        The IAM policy.
      """
      # Read-modify-write: fetch the current policy, mutate, then set.
      policy = self.GetIamPolicy(object_ref)
      iam_util.AddBindingToIamPolicy(self.messages.ApigatewayBinding, policy,
                                     member, role)
      return self.SetIamPolicy(object_ref, policy, 'bindings,etag')

    def RemoveIamPolicyBinding(self, object_ref, member, role):
      """Removes an IAM role for a member on an object.

      Args:
        self: The self of the class this is set on
        object_ref: Resource, reference for object IAM policy belongs to
        member: the member the binding is removed for
        role: the role which is being removed from the member

      Returns:
        The IAM policy
      """
      policy = self.GetIamPolicy(object_ref)
      iam_util.RemoveBindingFromIamPolicy(policy, member, role)
      return self.SetIamPolicy(object_ref, policy, 'bindings,etag')
    # Bind the function to the method and set the attribute
    setattr(self, 'GetIamPolicy', types.MethodType(GetIamPolicy, self))
    setattr(self, 'SetIamPolicy', types.MethodType(SetIamPolicy, self))
    setattr(self, 'AddIamPolicyBinding', types.MethodType(AddIamPolicyBinding,
                                                          self))
    setattr(self, 'RemoveIamPolicyBinding', types.MethodType(
        RemoveIamPolicyBinding, self))

View File

@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Client for interaction with Gateway CRUD on API Gateway API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.api_gateway import base
from googlecloudsdk.command_lib.api_gateway import common_flags
class GatewayClient(base.BaseClient):
  """Client for gateway objects on Cloud API Gateway API."""

  def __init__(self, client=None):
    base.BaseClient.__init__(self,
                             client=client,
                             message_base='ApigatewayProjectsLocationsGateways',
                             service_name='projects_locations_gateways')
    self.DefineGet()
    self.DefineDelete()
    self.DefineList('gateways')
    self.DefineUpdate('apigatewayGateway')
    self.DefineIamPolicyFunctions()

  def Create(self, gateway_ref, api_config, display_name=None, labels=None):
    """Creates a new gateway object.

    Args:
      gateway_ref: Resource, a resource reference for the gateway
      api_config: Resource, a resource reference for the gateway
      display_name: Optional display name
      labels: Optional cloud labels

    Returns:
      Long running operation.
    """
    label_values = common_flags.ProcessLabelsFlag(
        labels, self.messages.ApigatewayGateway.LabelsValue)
    new_gateway = self.messages.ApigatewayGateway(
        name=gateway_ref.RelativeName(),
        apiConfig=api_config.RelativeName(),
        displayName=display_name,
        labels=label_values,
    )
    create_req = self.create_request(
        apigatewayGateway=new_gateway,
        gatewayId=gateway_ref.Name(),
        parent=gateway_ref.Parent().RelativeName(),
    )
    return self.service.Create(create_req)

View File

@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Client for interaction with Operations CRUD on API Gateway API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.api_gateway import base
from googlecloudsdk.api_lib.util import waiter
class OperationsClient(base.BaseClient):
  """Client for operation objects on Cloud API Gateway API."""

  def __init__(self, client=None):
    base.BaseClient.__init__(
        self,
        client=client,
        message_base='ApigatewayProjectsLocationsOperations',
        service_name='projects_locations_operations')
    self.DefineGet()
    self.DefineList('operations', is_operations=True)

  def Cancel(self, operation_ref):
    """Cancels an operation.

    Args:
      operation_ref: The message to process (expected to be of type Operation)

    Returns:
      (Empty) The response message.
    """
    cancel_req = (
        self.messages.ApigatewayProjectsLocationsOperationsCancelRequest(
            name=operation_ref.RelativeName()))
    return self.service.Cancel(cancel_req)

  def WaitForOperation(self, operation_ref, message=None, service=None):
    """Waits for the given google.longrunning.Operation to complete.

    Args:
      operation_ref: The operation to poll.
      message: String to display for default progress_tracker.
      service: The service to get the resource after the long running
        operation completes.

    Raises:
      apitools.base.py.HttpError: if the request returns an HTTP error

    Returns:
      The Operation or the Resource the Operation is associated with.
    """
    # Callers that are resource-aware pass the service for the Resource the
    # Operation acts on, and receive that Resource back (e.g. `gateways
    # create` yields an ApigatewayGateway). Callers that are not (e.g.
    # `operations wait`) pass no service and receive the Operation itself.
    if service is not None:
      poller = waiter.CloudOperationPoller(service, self.service)
    else:
      poller = waiter.CloudOperationPollerNoResources(self.service)
    wait_message = message
    if wait_message is None:
      wait_message = 'Waiting for Operation [{}] to complete'.format(
          operation_ref.RelativeName())
    return waiter.WaitFor(poller, operation_ref, wait_message)

View File

@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The api_registry command group for the Cloud API Registry API."""
from googlecloudsdk.calliope import base
# NOTE: Release track decorators can be used here as well, and would propagate
# to this group's children.
# Calliope renders this class docstring as the command group's help text.
class ApiRegistry(base.Group):
  """Manage API Registry Command Group."""

View File

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The mcp command group for the Cloud API Registry API."""
from googlecloudsdk.calliope import base
# NOTE: Release track decorators can be used here as well, and would propagate
# to this group's children.
# Calliope renders this class docstring as the command group's help text.
class Mcp(base.Group):
  """Manage API Registry MCP Command Group.

  This command group is used to enable and disable MCP enablement for a given
  service in the current project.

  The current library contains utility functions that are used by the
  enablement and disablement commands.
  """

View File

@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for MCP command tests."""
from apitools.base.py import encoding
from apitools.base.py import exceptions as apitools_exceptions
from apitools.base.py.testing import mock
from googlecloudsdk.api_lib.util import apis as core_apis
from googlecloudsdk.calliope import base as calliope_base
from googlecloudsdk.core import properties
from tests.lib import cli_test_base
from tests.lib import parameterized
from tests.lib import sdk_test_base
class McpTestBaseForEnableDisableTests(
    sdk_test_base.WithFakeAuth,
    parameterized.TestCase,
    cli_test_base.CliTestBase,
):
  """Base class for MCP enable and disable command tests.

  Provides a mocked Service Usage (v2beta) client and helpers that register
  the expected request/response pairs for the MCP policy get/patch/poll flow.
  """

  def SetUp(self):
    # Pin the project so the expected resource names below are deterministic.
    self.project = 'test-gcp-project-12345'
    properties.VALUES.core.project.Set(self.project)
    self.su_messages = core_apis.GetMessagesModule('serviceusage', 'v2beta')
    # Mock the Service Usage client; no_http prevents any real transport.
    self.mocked_su_client = mock.Client(
        core_apis.GetClientClass('serviceusage', 'v2beta'),
        real_client=core_apis.GetClientInstance(
            'serviceusage', 'v2beta', no_http=True))
    self.mocked_su_client.Mock()
    self.addCleanup(self.mocked_su_client.Unmock)

  def _MakeHttpError(self, status, message='error'):
    """Builds a minimal apitools HttpError with the given status code."""
    return apitools_exceptions.HttpError({'status': status}, message, '')

  def _expectGetMcpPolicyCall(self, project, policy_old, exception=None):
    """Registers an expected McpPolicies.Get call on the mocked client."""
    expected_name = f'projects/{project}/mcpPolicies/default'
    expected_request = self.su_messages.ServiceusageMcpPoliciesGetRequest(
        name=expected_name
    )
    self.mocked_su_client.mcpPolicies.Get.Expect(
        request=expected_request,
        response=policy_old if not exception else None,
        exception=exception,
    )

  def _expectUpdateMcpPolicyCall(
      self, policy_new, operation_name, exception=None):
    """Registers an expected McpPolicies.Patch call returning an LRO."""
    expected_request = self.su_messages.ServiceusageMcpPoliciesPatchRequest(
        mcpPolicy=policy_new,
        force=False,
        name='projects/test-gcp-project-12345/mcpPolicies/default',
        validateOnly=False,
    )
    mock_operation = self.su_messages.Operation(
        name=operation_name,
        done=False  # Typically starts as not done
    )
    self.mocked_su_client.mcpPolicies.Patch.Expect(
        request=expected_request,
        response=mock_operation if not exception else None,
        exception=exception,
    )

  def _expectGetOperationCall(self, operation_name, policy_new, exception=None):
    """Registers an expected Operations.Get call that completes the LRO."""
    expected_request = self.su_messages.ServiceusageOperationsGetRequest(
        name=operation_name
    )
    # Round-trip the policy through a py value to embed it as the
    # Operation.response payload.
    response_value = encoding.PyValueToMessage(
        self.su_messages.Operation.ResponseValue,
        encoding.MessageToPyValue(policy_new)
    )
    response_op = None
    if not exception:
      response_op = self.su_messages.Operation(
          name=operation_name,
          done=True,
          response=response_value
      )
    self.mocked_su_client.operations.Get.Expect(
        request=expected_request,
        response=response_op,
        exception=exception,
    )

  def _expectGetServiceCall(self, project, service_name, service_state,
                            exception=None):
    """Registers an expected Services.Get call with the full state view."""
    expected_name = f'projects/{project}/services/{service_name}'
    expected_request = self.su_messages.ServiceusageServicesGetRequest(
        name=expected_name,
        view=self.su_messages.ServiceusageServicesGetRequest.ViewValueValuesEnum.SERVICE_STATE_VIEW_FULL
    )
    self.mocked_su_client.services.Get.Expect(
        request=expected_request,
        response=service_state if not exception else None,
        exception=exception,
    )
class McpAlphaForEnableDisableTests(McpTestBaseForEnableDisableTests):
  """Base class for MCP enable and disable command tests in alpha track."""

  def PreSetUp(self):
    # Run the shared test suite against the ALPHA command surface.
    self.track = calliope_base.ReleaseTrack.ALPHA
class McpBetaForEnableDisableTests(McpTestBaseForEnableDisableTests):
  """Base class for MCP enable and disable command tests in beta track."""

  def PreSetUp(self):
    # Run the shared test suite against the BETA command surface.
    self.track = calliope_base.ReleaseTrack.BETA

View File

@@ -0,0 +1,87 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for MCP Servers API client."""
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.api_registry import utils
class McpServersClient(object):
  """Client for MCP Servers API."""

  def __init__(self, version, client=None, messages=None):
    """Initializes the client.

    Args:
      version: str, the API version to target.
      client: Optional apitools client; a new one is built when omitted.
      messages: Optional messages module; derived from the client when
        omitted.
    """
    self.client = client or utils.GetClientInstance(version=version)
    self.messages = messages or utils.GetMessagesModule(
        version, client=self.client
    )
    self._service = self.client.projects_locations_mcpServers

  def _List(self, request, args):
    """Shared implementation for listing MCP Servers (all tracks).

    Args:
      request: (str) the parent resource name to list under.
      args: (arg_parsers.ArgumentParser) command line arguments.

    Returns:
      A generator yielding MCP Servers.
    """
    # NOTE(review): with --all the filter becomes enabled=false, which reads
    # as "only disabled servers" rather than "all servers" — confirm the
    # intended API semantics.
    filter_str = 'enabled=true'
    if args.all:
      filter_str = 'enabled=false'
    list_req = (
        self.messages.
        CloudapiregistryProjectsLocationsMcpServersListRequest(
            parent=request, filter=filter_str))
    return list_pager.YieldFromList(
        self._service,
        list_req,
        field='mcpServers',
        batch_size_attribute='pageSize')

  def ListAlpha(self, request, args):
    """List MCP Servers in the API Registry (alpha track).

    Args:
      request: (str) the parent resource name to list under.
      args: (arg_parsers.ArgumentParser) command line arguments.

    Returns:
      A list of MCP Servers.
    """
    return self._List(request, args)

  def ListBeta(self, request, args):
    """List MCP Servers in the API Registry (beta track).

    Args:
      request: (str) the parent resource name to list under.
      args: (arg_parsers.ArgumentParser) command line arguments.

    Returns:
      A list of MCP Servers.
    """
    # Identical to the alpha behavior today; both delegate to _List so the
    # filter logic stays in one place.
    return self._List(request, args)

View File

@@ -0,0 +1,87 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for MCP Tools API client."""
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.api_registry import utils
class McpToolsClient(object):
  """Client for MCP Tools API."""

  def __init__(self, version, client=None, messages=None):
    """Initializes the client.

    Args:
      version: str, the API version to target.
      client: Optional apitools client; a new one is built when omitted.
      messages: Optional messages module; derived from the client when
        omitted.
    """
    self.client = client or utils.GetClientInstance(version=version)
    self.messages = messages or utils.GetMessagesModule(
        version, client=self.client
    )
    self._service = self.client.projects_locations_mcpServers_mcpTools

  def _List(self, request, args):
    """Shared implementation for listing MCP Tools (all tracks).

    Args:
      request: (str) the parent resource name to list under.
      args: (arg_parsers.ArgumentParser) command line arguments.

    Returns:
      A generator yielding MCP Tools.
    """
    # TODO: b/460124490 - Add UTs for api_lib files too.
    # NOTE(review): with --all the filter becomes enabled=false, which reads
    # as "only disabled tools" rather than "all tools" — confirm the intended
    # API semantics.
    filter_str = 'enabled=true'
    if args.all:
      filter_str = 'enabled=false'
    list_req = (
        self.messages.
        CloudapiregistryProjectsLocationsMcpServersMcpToolsListRequest(
            parent=request, filter=filter_str))
    return list_pager.YieldFromList(
        self._service,
        list_req,
        field='mcpTools',
        batch_size_attribute='pageSize')

  def ListAlpha(self, request, args):
    """List MCP Tools in the API Registry (alpha track).

    Args:
      request: (str) the parent resource name to list under.
      args: (arg_parsers.ArgumentParser) command line arguments.

    Returns:
      A list of MCP Tools.
    """
    return self._List(request, args)

  def ListBeta(self, request, args):
    """List MCP Tools in the API Registry (beta track).

    Args:
      request: (str) the parent resource name to list under.
      args: (arg_parsers.ArgumentParser) command line arguments.

    Returns:
      A list of MCP Tools.
    """
    # Identical to the alpha behavior today; both delegate to _List so the
    # filter logic stays in one place.
    return self._List(request, args)

View File

@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for MCP Servers and Tools API."""
from googlecloudsdk.core import properties
_PROJECT_RESOURCE = 'projects/{}'
_MCP_POLICY_DEFAULT = '/mcpPolicies/default'

# FIX: these module-level functions were previously decorated with
# @staticmethod, which is only meaningful inside a class body; calling a
# bare staticmethod object raises TypeError on Python < 3.10. The
# decorators have been removed — the call sites are unchanged.


def GetProjectResource():
  """Returns the current project as a resource name (projects/{project_id})."""
  project_id = GetProjectId()
  return _PROJECT_RESOURCE.format(project_id)


def GetProjectId():
  """Returns the project ID configured for the current invocation."""
  return properties.VALUES.core.project.Get()


def GetMcpPolicyDefault():
  """Returns the default MCP policy path suffix ('/mcpPolicies/default').

  Note: this is only the suffix; callers append it to a project resource to
  form projects/{project_id}/mcpPolicies/default.
  """
  return _MCP_POLICY_DEFAULT

View File

@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for MCP Servers and Tools API."""
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.core import properties
_CLOUD_API_REGISTRY_API = 'cloudapiregistry'
def GetClientInstance(version, no_http=False):
  """Returns an apitools client instance for the Cloud API Registry API.

  Args:
    version: str, the API version.
    no_http: bool, True to build the client without an http object.
  """
  api_name = _CLOUD_API_REGISTRY_API
  return apis.GetClientInstance(api_name, version, no_http=no_http)
def GetMessagesModule(version, client=None):
  """Returns the generated messages module for the given API version.

  Args:
    version: str, the API version whose messages are wanted.
    client: an optional existing client; when omitted (or falsy), a new client
      for `version` is instantiated in order to obtain the module.

  Returns:
    The apitools MESSAGES_MODULE for the cloudapiregistry API.
  """
  if not client:
    client = GetClientInstance(version=version)
  return client.MESSAGES_MODULE
def GetProject():
  """Returns the core/project property value, failing if it is unset."""
  return properties.VALUES.core.project.GetOrFail()
def GetLocation():
  """Returns the fixed location segment used in API Registry resource names."""
  # Since API Registry is a global service right now.
  return 'global'

View File

@@ -0,0 +1,430 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python wrappers around Apigee Management APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import collections
import json
import re
from googlecloudsdk.api_lib.apigee import base
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.command_lib.apigee import errors
from googlecloudsdk.command_lib.apigee import request
from googlecloudsdk.command_lib.apigee import resource_args
from googlecloudsdk.core import log
class OrganizationsClient(base.BaseClient):
  """REST client for Apigee Organizations."""
  _entity_path = ["organization"]

  @classmethod
  def ListOrganizationsGlobal(cls):
    """Returns a list of Apigee organizations on the global endpoint.

    Raises:
      errors.RequestError: if the list request fails; rewritten to describe
        the attempted "list organization" operation.
    """
    try:
      return request.ResponseToApiRequest(
          identifiers=None,
          entity_path=cls._entity_path[:-1],
          entity_collection=cls._entity_path[-1],
          location="global")
    except errors.RequestError as error:
      # Rewrite error message to better describe what was attempted.
      raise error.RewrittenError("organization", "list")

  @classmethod
  def ProjectMapping(cls, identifiers):
    """Returns a mapping of GCP projects to Apigee organization.

    Args:
      identifiers: dict of identifiers naming the organization to look up.

    Raises:
      errors.RequestError (or one of its listed subtypes): if the request
        fails; rewritten to describe the attempted "get project mapping"
        operation.
    """
    try:
      return request.ResponseToApiRequest(
          identifiers,
          cls._entity_path,
          method=":getProjectMapping",
          method_override="GET",
          location="global",
      )
    except (
        errors.UnauthorizedRequestError,
        errors.EntityNotFoundError,
        errors.RequestError,
    ) as error:
      # The original code had two identical handlers (one for the two
      # specific error types, one for the generic RequestError), both
      # rewriting the error the same way; they are merged into one handler.
      # Rewrite error message to better describe what was attempted.
      raise error.RewrittenError("project mapping", "get")
class APIsClient(base.BaseClient):
  """REST client for Apigee API proxies."""
  _entity_path = ["organization", "api"]

  # Deployments hang off a revision within an environment, not off the API
  # proxy entity itself.
  _deployment_parent_path = ["organization", "environment", "api", "revision"]

  @classmethod
  def Deploy(cls, identifiers, override=False):
    """Deploys an API proxy revision to an environment.

    Args:
      identifiers: dict naming the organization, environment, API proxy, and
        revision to deploy.
      override: bool, whether to ask the server to override conflicting
        deployments.

    Returns:
      The parsed API response describing the new deployment.

    Raises:
      errors.RequestError: if the deploy request fails.
    """
    params = {}
    if override:
      params["override"] = "true"
    try:
      return request.ResponseToApiRequest(
          identifiers,
          cls._deployment_parent_path,
          "deployment",
          method="POST",
          query_params=params)
    except errors.RequestError as error:
      # Rewrite error message to better describe what was attempted.
      raise error.RewrittenError("API proxy", "deploy")

  @classmethod
  def Undeploy(cls, identifiers):
    """Removes an API proxy revision's deployment from an environment.

    Args:
      identifiers: dict naming the organization, environment, API proxy, and
        revision to undeploy.

    Raises:
      errors.RequestError: if the undeploy request fails.
    """
    try:
      return request.ResponseToApiRequest(
          identifiers,
          cls._deployment_parent_path,
          "deployment",
          method="DELETE")
    except errors.RequestError as error:
      # Rewrite error message to better describe what was attempted.
      raise error.RewrittenError("deployment", "undeploy")
class EnvironmentsClient(base.BaseClient):
  """REST client for Apigee environments (generic operations via BaseClient)."""
  _entity_path = ["organization", "environment"]
class RevisionsClient(base.BaseClient):
  """REST client for Apigee API proxy revisions."""
  _entity_path = ["organization", "api", "revision"]
class _DeveloperApplicationsClient(base.FieldPagedListClient):
  """Field-paged client for a single developer's applications.

  Used internally by ApplicationsClient.List when a developer is given.
  """
  _entity_path = ["organization", "developer", "app"]
  # The List response wraps its results in an "app" field.
  _list_container = "app"
  # Pages are keyed on each application's "name" field.
  _page_field = "name"
class OperationsClient(base.BaseClient):
  """REST client for Apigee long running operations."""
  _entity_path = ["organization", "operation"]

  @classmethod
  def SplitName(cls, operation_info):
    """Adds "organization" and "uuid" keys parsed from the operation's name.

    Args:
      operation_info: dict describing the operation; must contain a "name".

    Returns:
      The same dict, augmented in place when the name matches the expected
      organizations/{org}/operations/{uuid} pattern; unchanged otherwise.
    """
    match = re.match(
        r"organizations/([a-z][-a-z0-9]{0,30}[a-z0-9])/operations/"
        r"([0-9a-fA-F]{8}-([0-9a-fA-F]{4}-){3}[0-9a-fA-F]{12})",
        operation_info["name"])
    if match:
      operation_info["organization"] = match.group(1)
      operation_info["uuid"] = match.group(2)
    return operation_info

  @classmethod
  def List(cls, identifiers):
    """Yields each operation, augmented with its parsed name components."""
    response = super(OperationsClient, cls).List(identifiers)
    if not response:
      return
    for operation in response["operations"]:
      yield cls.SplitName(operation)

  @classmethod
  def Describe(cls, identifiers):
    """Returns one operation, augmented with its parsed name components."""
    raw_operation = super(OperationsClient, cls).Describe(identifiers)
    return cls.SplitName(raw_operation)
class ProjectsClient(base.BaseClient):
  """REST client for Apigee APIs related to GCP projects."""
  _entity_path = ["project"]

  @classmethod
  def ProvisionOrganization(cls, project_id, org_info, location=None):
    """Provisions an Apigee organization for the given GCP project.

    Args:
      project_id: str, the GCP project to provision an organization for.
      org_info: JSON-serializable description of the organization to create.
      location: optional str, the endpoint location for the request.

    Returns:
      The parsed API response.
    """
    identifiers = {"projectsId": project_id}
    payload = json.dumps(org_info)
    return request.ResponseToApiRequest(
        identifiers,
        cls._entity_path,
        method=":provisionOrganization",
        body=payload,
        location=location)
class ApplicationsClient(base.FieldPagedListClient):
  """REST client for Apigee applications."""
  _entity_path = ["organization", "app"]
  _list_container = "app"
  _page_field = "appId"
  # This legacy API names its page-size parameter "rows".
  _limit_param = "rows"

  @classmethod
  def List(cls, identifiers):
    """Yields {"appId", "name"} dicts for each matching application.

    When a developer is present in `identifiers`, the developer-scoped list
    API is used with shallow expansion; otherwise the organization-wide list
    API is used with full expansion.

    Args:
      identifiers: dict naming the organization (and optionally a developer)
        whose applications should be listed.
    """
    if identifiers.get("developersId"):
      items = _DeveloperApplicationsClient.List(
          identifiers, extra_params={"shallowExpand": "true"})
    else:
      items = super(ApplicationsClient, cls).List(
          identifiers, extra_params={"expand": "true"})
    for item in items:
      yield {"appId": item["appId"], "name": item["name"]}
class DevelopersClient(base.FieldPagedListClient):
  """REST client for Apigee developers, paged on developer email."""
  _entity_path = ["organization", "developer"]
  # The List response wraps its results in a "developer" field.
  _list_container = "developer"
  # Pages are keyed on each developer's "email" field.
  _page_field = "email"
class DeploymentsClient(object):
  """REST client for Apigee deployments."""

  @classmethod
  def List(cls, identifiers):
    """Returns a list of deployments, filtered by `identifiers`.

    The deployment-listing API, unlike most GCP APIs, is very flexible as to
    what kinds of objects are provided as the deployments' parents. An
    organization is required, but any combination of environment, proxy or
    shared flow, and API revision can be given in addition to that.

    Args:
      identifiers: dictionary with fields that describe which deployments to
        list. `organizationsId` is required. `environmentsId`, `apisId`, and
        `revisionsId` can be optionally provided to further filter the list.
        Shared flows are not yet supported.

    Returns:
      A list of Apigee deployments, each represented by a parsed JSON object.
    """
    identifier_names = ["organization", "environment", "api", "revision"]
    entities = [resource_args.ENTITIES[name] for name in identifier_names]
    # Build the request's entity path from only those entities that are
    # actually present in `identifiers`, preserving the canonical order.
    entity_path = []
    for entity in entities:
      key = entity.plural + "Id"
      if key in identifiers and identifiers[key] is not None:
        entity_path.append(entity.singular)
    if "revision" in entity_path and "api" not in entity_path:
      # Revision is notionally a part of API proxy and can't be specified
      # without it. Behave as though neither API proxy nor revision were given.
      entity_path.remove("revision")
    try:
      response = request.ResponseToApiRequest(identifiers, entity_path,
                                              "deployment")
    except errors.EntityNotFoundError:
      # If there were no matches, that's just an empty list of matches.
      response = []
    # The different endpoints this method can hit return different formats.
    # Translate them all into a single format.
    if "apiProxy" in response:
      # A single deployment object was returned; wrap it in a list.
      return [response]
    if "deployments" in response:
      # A wrapper object containing the list of deployments.
      return response["deployments"]
    if not response:
      return []
    return response
# Value type holding every settable field of an Apigee API product. Fields
# left as None are omitted from create/update request bodies (see
# ProductsClient below).
ProductsInfo = collections.namedtuple("ProductsInfo", [
    "name", "displayName", "approvalType", "attributes", "description",
    "apiResources", "environments", "proxies", "quota", "quotaInterval",
    "quotaTimeUnit", "scopes"
])
class ProductsClient(base.FieldPagedListClient):
  """REST client for Apigee API products."""
  _entity_path = ["organization", "product"]
  _list_container = "apiProduct"
  _page_field = "name"

  @staticmethod
  def _NonNullFields(product_info):
    """Returns `product_info` as a dict with all None-valued fields removed.

    Create and Update previously duplicated this filtering inline; it is
    shared here so the two request bodies cannot drift apart.

    Args:
      product_info: a ProductsInfo namedtuple describing the product.

    Returns:
      A dict mapping field name to value for every field with a value.
    """
    product_dict = product_info._asdict()
    # Don't send fields unless there's a value for them.
    return {
        key: value
        for key, value in product_dict.items()
        if value is not None
    }

  @classmethod
  def Create(cls, identifiers, product_info):
    """Creates an API product in the organization named by `identifiers`.

    Args:
      identifiers: dict naming the parent organization.
      product_info: a ProductsInfo namedtuple describing the new product.

    Returns:
      The parsed API response.
    """
    return request.ResponseToApiRequest(
        identifiers, ["organization"],
        "product",
        method="POST",
        body=json.dumps(cls._NonNullFields(product_info)))

  @classmethod
  def Update(cls, identifiers, product_info):
    """Replaces the API product named by `identifiers` with `product_info`.

    Args:
      identifiers: dict naming the organization and product to update.
      product_info: a ProductsInfo namedtuple with the replacement fields.

    Returns:
      The parsed API response.
    """
    return request.ResponseToApiRequest(
        identifiers, ["organization", "product"],
        method="PUT",
        body=json.dumps(cls._NonNullFields(product_info)))
class ArchivesClient(base.TokenPagedListClient):
  """Client for the Apigee archiveDeployments API."""
  # These are the entity names used internally by gcloud.
  _entity_path = ["organization", "environment", "archive_deployment"]
  # The List response wraps its results in an "archiveDeployments" field.
  _list_container = "archiveDeployments"

  @classmethod
  def Update(cls, identifiers, labels):
    """Calls the 'update' API for archive deployments.

    Args:
      identifiers: Dict of identifiers for the request entity path, which must
        include "organizationsId", "environmentsId" and "archiveDeploymentsId".
      labels: Dict of the labels proto to update, in the form of:
        {"labels": {"key1": "value1", "key2": "value2", ... "keyN": "valueN"}}

    Returns:
      A dict of the updated archive deployment.

    Raises:
      command_lib.apigee.errors.RequestError if there is an error with the API
      request.
    """
    try:
      return request.ResponseToApiRequest(
          identifiers,
          entity_path=cls._entity_path,
          method="PATCH",
          body=json.dumps(labels))
    except errors.RequestError as error:
      # Rewrite error message to better describe what was attempted.
      raise error.RewrittenError("archive deployment", "update")

  @classmethod
  def List(cls, identifiers):
    """Calls the 'list' API for archive deployments.

    Args:
      identifiers: Dict of identifiers for the request entity path, which must
        include "organizationsId" and "environmentsId".

    Returns:
      An iterable of archive deployments.

    Raises:
      command_lib.apigee.errors.RequestError if there is an error with the API
      request.
    """
    try:
      # NOTE(review): the parent List is a generator, so request errors are
      # raised during iteration; this handler only covers errors raised while
      # creating the generator — confirm whether callers rely on the rewrite.
      return super(ArchivesClient, cls).List(identifiers)
    except errors.RequestError as error:
      # Rewrite error message to better describe what was attempted.
      raise error.RewrittenError("archive deployment", "list")

  @classmethod
  def GetUploadUrl(cls, identifiers):
    """Apigee API for generating a signed URL for uploading archives.

    This API uses the custom method:
    organizations/*/environments/*/archiveDeployments:generateUploadUrl

    Args:
      identifiers: Dict of identifiers for the request entity path, which must
        include "organizationsId" and "environmentsId".

    Returns:
      A dict of the API response in the form of:
      {"uploadUri": "https://storage.googleapis.com/ ... (full URI)"}

    Raises:
      command_lib.apigee.errors.RequestError if there is an error with the API
      request.
    """
    try:
      # The API call doesn't need to specify an archiveDeployment resource id,
      # so only the "organizations/environments" entity path is needed.
      # "archiveDeployment" is provided as the entity_collection argument.
      return request.ResponseToApiRequest(
          identifiers,
          entity_path=cls._entity_path[:-1],
          entity_collection=cls._entity_path[-1],
          method=":generateUploadUrl")
    except errors.RequestError as error:
      # Rewrite error message to better describe what was attempted.
      raise error.RewrittenError("archive deployment", "get upload url for")

  @classmethod
  def CreateArchiveDeployment(cls, identifiers, post_data):
    """Apigee API for creating a new archive deployment.

    Args:
      identifiers: A dict of identifiers for the request entity path, which must
        include "organizationsId" and "environmentsId".
      post_data: A dict of the request body to include in the
        CreateArchiveDeployment API call.

    Returns:
      A dict of the API response. The API call starts a long-running operation,
      so the response dict will contain info about the operation id.

    Raises:
      command_lib.apigee.errors.RequestError if there is an error with the API
      request.
    """
    try:
      # The API call doesn't need to specify an archiveDeployment resource name
      # so only the "organizations/environments" entity path is needed.
      # "archive_deployment" is provided as the entity_collection argument.
      return request.ResponseToApiRequest(
          identifiers,
          cls._entity_path[:-1],
          cls._entity_path[-1],
          method="POST",
          body=json.dumps(post_data))
    except errors.RequestError as error:
      # Rewrite error message to better describe what was attempted.
      raise error.RewrittenError("archive deployment", "create")
class LROPoller(waiter.OperationPoller):
  """Polls on completion of an Apigee long running operation."""

  def __init__(self, organization):
    """Creates a poller bound to one Apigee organization.

    Args:
      organization: str, the organization whose operations will be polled.
    """
    super(LROPoller, self).__init__()
    self.organization = organization

  def IsDone(self, operation):
    """Returns True when `operation` has reached the FINISHED state.

    Args:
      operation: dict describing the operation, as returned by Poll.

    Raises:
      waiter.OperationError: if the operation dict lacks expected fields.
      errors.RequestError: if the operation finished with an error attached.
    """
    try:
      state = operation["metadata"]["state"]
    except KeyError as err:
      raise waiter.OperationError("Malformed operation; %s\n%r" %
                                  (err, operation))
    finished = state == "FINISHED"
    if finished and "error" in operation:
      raise errors.RequestError(
          "operation", {"name": operation["name"]},
          "await",
          body=json.dumps(operation))
    return finished

  def Poll(self, operation_uuid):
    """Fetches the latest state of the identified operation."""
    identifiers = {
        "organizationsId": self.organization,
        "operationsId": operation_uuid,
    }
    return OperationsClient.Describe(identifiers)

  def GetResult(self, operation):
    """Returns the operation's response payload, or None if there is none."""
    return operation.get("response")

View File

@@ -0,0 +1,190 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic implementations of Apigee Management APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.command_lib.apigee import request
class BaseClient(object):
  """Base class for Apigee Management API clients.

  Subclasses must set _entity_path to the ordered list of identifier names
  that uniquely locate the entity, in the order the REST API expects.
  """
  # Ordered identifier names for the entity; None marks the class abstract.
  _entity_path = None

  @classmethod
  def _CheckEntityPath(cls):
    # Shared guard for the generic operations below.
    if cls._entity_path is None:
      raise NotImplementedError("%s class must provide an entity path." % cls)

  @classmethod
  def List(cls, identifiers=None, extra_params=None):
    """Lists the entities in the collection named by `identifiers`."""
    cls._CheckEntityPath()
    parent_path = cls._entity_path[:-1]
    collection = cls._entity_path[-1]
    return request.ResponseToApiRequest(
        identifiers or {},
        parent_path,
        collection,
        query_params=extra_params)

  @classmethod
  def Describe(cls, identifiers=None):
    """Returns the entity named by `identifiers`."""
    cls._CheckEntityPath()
    return request.ResponseToApiRequest(identifiers or {}, cls._entity_path)

  @classmethod
  def Delete(cls, identifiers=None):
    """Deletes the entity named by `identifiers`."""
    cls._CheckEntityPath()
    return request.ResponseToApiRequest(
        identifiers or {}, cls._entity_path, method="DELETE")
class PagedListClient(BaseClient):
  """Client for `List` APIs that can only return a limited number of objects.

  Attributes:
    _list_container: the field name in the List API's response that contains
      the list of objects. None if the API returns a list directly.
  """
  _list_container = None

  @classmethod
  def _NormalizedResultChunk(cls, result_chunk):
    """Returns a list of the results in `result_chunk`.

    When _list_container is set, the chunk must be a JSON object with that
    field; otherwise the chunk itself is already the result list.

    Raises:
      AssertionError: if the response's shape contradicts _list_container.
    """
    container = cls._list_container
    if container is None:
      return result_chunk
    try:
      return result_chunk[container]
    except KeyError:
      raise AssertionError(
          "%s specifies a _list_container %r that's not present in API "
          "responses.\nResponse: %r" % (cls, container, result_chunk))
    except (IndexError, TypeError):
      raise AssertionError(
          ("%s specifies a _list_container, implying that the API "
           "response should be a JSON object, but received something "
           "else instead: %r") % (cls, result_chunk))
class TokenPagedListClient(PagedListClient):
  """Client for paged `List` APIs that identify pages using a page token.

  This is the AIP-approved way to paginate results and is preferred for new
  APIs.

  Attributes:
    _page_token_field: the field name in the List API's response that contains
      an explicit page token.
    _list_container: the field name in the List API's response that contains
      the list of objects.
    _page_token_param: the query parameter for the previous page's token.
    _max_per_page: the maximum number of items that can be returned in each
      List response.
    _limit_param: the query parameter for the number of items to be returned
      on each page.
  """
  _page_token_field = "nextPageToken"
  _page_token_param = "pageToken"
  _max_per_page = 100
  _limit_param = "pageSize"

  @classmethod
  def List(cls, identifiers=None, extra_params=None):
    """Yields every object in the collection, requesting pages as needed."""
    if cls._list_container is None:
      raise AssertionError(
          ("%s does not specify a _list_container, but token pagination "
           "requires it") % (cls))
    query = {cls._limit_param: cls._max_per_page}
    if extra_params:
      query.update(extra_params)
    next_page_token = True
    while next_page_token:
      response = super(TokenPagedListClient, cls).List(identifiers, query)
      for item in cls._NormalizedResultChunk(response):
        yield item
      # An absent or blank token means there are no more pages.
      next_page_token = (cls._page_token_field in response and
                         response[cls._page_token_field])
      if next_page_token:
        query[cls._page_token_param] = next_page_token
class FieldPagedListClient(PagedListClient):
  """Client for paged `List` APIs that identify pages using a page field.

  This is the pagination method used by legacy Apigee CG APIs, and has been
  preserved for backwards compatibility in Apigee's GCP offering.

  Attributes:
    _list_container: the field name in the List API's response that contains the
      list of objects. None if the API returns a list directly.
    _page_field: the field name in each list element that can be used as a page
      identifier. FieldPagedListClient will take the value of this field in the
      last list item for a page, and use it as the _start_at_param for the next
      page. None if each list element is a primitive which can be used for this
      purpose directly.
    _max_per_page: the maximum number of items that can be returned in each List
      response.
    _limit_param: the query parameter for the number of items to be returned on
      each page.
    _start_at_param: the query parameter for where in the available data the
      response should begin.
  """
  _page_field = None
  _max_per_page = 1000
  _limit_param = "count"
  _start_at_param = "startKey"

  @classmethod
  def List(cls, identifiers=None, start_at_param=None, extra_params=None):
    """Yields every object in the collection, one legacy page at a time.

    Args:
      identifiers: dict naming the collection's parent entities.
      start_at_param: optional override for the start-key query parameter name.
      extra_params: optional extra query parameters for each page request.

    Yields:
      Parsed list elements, de-duplicated across overlapping page boundaries.
    """
    if start_at_param is None:
      start_at_param = cls._start_at_param
    params = {cls._limit_param: cls._max_per_page}
    if extra_params:
      params.update(extra_params)
    while True:
      result_chunk = super(FieldPagedListClient, cls).List(identifiers, params)
      if not result_chunk and start_at_param not in params:
        # First request returned no rows; entire dataset is empty.
        return
      if cls._list_container is not None:
        # This API is expected to return a dictionary with a list inside it.
        # Extract the result list out of the dictionary for further processing.
        result_chunk = cls._NormalizedResultChunk(result_chunk)
      # For legacy pagination, the last item in a full page is also the first
      # item in the next page. Don't yield it yet; the next page will yield it
      # instead.
      for item in result_chunk[:cls._max_per_page - 1]:
        yield item
      if len(result_chunk) < cls._max_per_page:
        # Server didn't have enough values to fill the page, so all results have
        # been received.
        break
      # Ask the next request to start at this page's final item (or, when
      # _page_field is set, at that item's page-identifier field).
      last_item_on_page = result_chunk[-1]
      if cls._page_field is not None:
        last_item_on_page = last_item_on_page[cls._page_field]
      params[start_at_param] = last_item_on_page

View File

@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for creating a client to talk to the App Engine Admin API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.util import apis as core_apis
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
class AppengineApiClientBase(object):
  """Base class for App Engine API clients."""

  def __init__(self, client):
    """Wraps `client` and resolves the active project.

    Args:
      client: an apitools client for the App Engine Admin API.
    """
    self.client = client
    self.project = properties.VALUES.core.project.Get(required=True)

  @property
  def messages(self):
    """The generated messages module matching self.client."""
    return self.client.MESSAGES_MODULE

  @classmethod
  def ApiVersion(cls):
    """Returns the default Admin API version used by this client class."""
    return 'v1'

  @classmethod
  def GetApiClient(cls, api_version=None):
    """Initializes an AppengineApiClient using the specified API version.

    Uses the api_client_overrides/appengine property to determine which client
    version to use if api_version is not set. Additionally uses the
    api_endpoint_overrides/appengine property to determine the server endpoint
    for the App Engine API.

    Args:
      api_version: The api version override.

    Returns:
      An AppengineApiClient used by gcloud to communicate with the App Engine
      API.

    Raises:
      ValueError: If default_version does not correspond to a supported version
        of the API.
    """
    version = api_version if api_version is not None else cls.ApiVersion()
    return cls(core_apis.GetClientInstance('appengine', version))

  def _FormatApp(self):
    """Returns the full resource name of this client's application."""
    app_ref = resources.REGISTRY.Parse(
        self.project, params={}, collection='appengine.apps')
    return app_ref.RelativeName()

  def _GetServiceRelativeName(self, service_name):
    """Returns the full resource name of a service in this application."""
    service_ref = resources.REGISTRY.Parse(
        service_name,
        params={'appsId': self.project},
        collection='appengine.apps.services')
    return service_ref.RelativeName()

  def _FormatVersion(self, service_name, version_id):
    """Returns the full resource name of one version of a service."""
    version_ref = resources.REGISTRY.Parse(
        version_id,
        params={'appsId': self.project,
                'servicesId': service_name},
        collection='appengine.apps.services.versions')
    return version_ref.RelativeName()

  def _FormatOperation(self, op_id):
    """Returns the full resource name of an operation on this application."""
    op_ref = resources.REGISTRY.Parse(
        op_id,
        params={'appsId': self.project},
        collection='appengine.apps.operations')
    return op_ref.RelativeName()

View File

@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for creating a client to talk to the App Engine Admin API."""
from googlecloudsdk.api_lib.app import operations_util
from googlecloudsdk.api_lib.app.api import appengine_api_client_base as base
from googlecloudsdk.calliope import base as calliope_base
from googlecloudsdk.core import log
from googlecloudsdk.core import resources
# Default Admin API version for `gcloud app update`.
DEFAULT_VERSION = 'v1beta'
# 'app update' is currently only exposed in beta.
# Every release track is therefore pinned to the same (beta) API version.
UPDATE_VERSIONS_MAP = {
    calliope_base.ReleaseTrack.GA: DEFAULT_VERSION,
    calliope_base.ReleaseTrack.ALPHA: DEFAULT_VERSION,
    calliope_base.ReleaseTrack.BETA: DEFAULT_VERSION
}
def GetApiClientForTrack(release_track):
  """Returns an AppengineAppUpdateApiClient for the given release track."""
  api_version = UPDATE_VERSIONS_MAP[release_track]
  return AppengineAppUpdateApiClient.GetApiClient(api_version)
class AppengineAppUpdateApiClient(base.AppengineApiClientBase):
  """Client used by gcloud to communicate with the App Engine API."""

  def __init__(self, client):
    base.AppengineApiClientBase.__init__(self, client)
    self._registry = resources.REGISTRY.Clone()
    # pylint: disable=protected-access
    self._registry.RegisterApiByName('appengine', client._VERSION)

  def PatchApplication(
      self, split_health_checks=None, service_account=None, ssl_policy=None
  ):
    """Updates an application.

    Args:
      split_health_checks: Boolean, whether to enable split health checks by
        default.
      service_account: str, the app-level default service account to update for
        this App Engine app.
      ssl_policy: enum, the app-level SSL policy to update for this App Engine
        app. Can be DEFAULT or MODERN.

    Returns:
      Long running operation.
    """
    # Accumulate the field mask one entry at a time. Each entry carries the
    # trailing comma the original string concatenation produced, so the
    # resulting mask string is byte-identical.
    mask_fields = []
    if split_health_checks is not None:
      mask_fields.append('featureSettings.splitHealthChecks')
    if service_account is not None:
      mask_fields.append('serviceAccount')
    if ssl_policy is not None:
      mask_fields.append('sslPolicy')
    update_mask = ''.join(field + ',' for field in mask_fields)
    application_update = self.messages.Application(
        featureSettings=self.messages.FeatureSettings(
            splitHealthChecks=split_health_checks),
        serviceAccount=service_account,
        sslPolicy=ssl_policy)
    update_request = self.messages.AppengineAppsPatchRequest(
        name=self._FormatApp(),
        application=application_update,
        updateMask=update_mask)
    operation = self.client.apps.Patch(update_request)
    log.debug('Received operation: [{operation}] with mask [{mask}]'.format(
        operation=operation.name,
        mask=update_mask))
    return operations_util.WaitForOperation(self.client.apps_operations,
                                            operation)

View File

@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for creating a client to talk to the App Engine Admin API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.app import operations_util
from googlecloudsdk.api_lib.app.api import appengine_api_client_base as base
from googlecloudsdk.calliope import base as calliope_base
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.core import resources
# Admin API version to use for each gcloud release track.
DOMAINS_VERSION_MAP = {
    calliope_base.ReleaseTrack.GA: 'v1',
    calliope_base.ReleaseTrack.ALPHA: 'v1alpha',
    calliope_base.ReleaseTrack.BETA: 'v1beta'
}
def GetApiClientForTrack(release_track):
  """Returns an AppengineDomainsApiClient for the given release track."""
  api_version = DOMAINS_VERSION_MAP[release_track]
  return AppengineDomainsApiClient.GetApiClient(api_version)
class AppengineDomainsApiClient(base.AppengineApiClientBase):
  """Client used by gcloud to communicate with the App Engine API."""

  def __init__(self, client):
    base.AppengineApiClientBase.__init__(self, client)
    self._registry = resources.REGISTRY.Clone()
    # pylint: disable=protected-access
    self._registry.RegisterApiByName('appengine', client._VERSION)

  def DeleteDomainMapping(self, domain):
    """Deletes a domain mapping for the given application.

    Blocks until the deletion operation completes.

    Args:
      domain: str, the domain to delete.
    """
    delete_request = self.messages.AppengineAppsDomainMappingsDeleteRequest(
        name=self._FormatDomainMapping(domain))
    operation = self.client.apps_domainMappings.Delete(delete_request)
    operations_util.WaitForOperation(self.client.apps_operations, operation)

  def GetDomainMapping(self, domain):
    """Fetches a single domain mapping of the given application.

    Args:
      domain: str, the domain to retrieve.

    Returns:
      The retrieved DomainMapping object.
    """
    get_request = self.messages.AppengineAppsDomainMappingsGetRequest(
        name=self._FormatDomainMapping(domain))
    return self.client.apps_domainMappings.Get(get_request)

  def ListDomainMappings(self):
    """Lists every domain mapping of the given application.

    Returns:
      A list of DomainMapping objects.
    """
    list_request = self.messages.AppengineAppsDomainMappingsListRequest(
        parent=self._FormatApp())
    return self.client.apps_domainMappings.List(list_request).domainMappings

  def ListVerifiedDomains(self):
    """Lists every domain verified by the current user.

    Returns:
      A list of AuthorizedDomain objects.
    """
    list_request = self.messages.AppengineAppsAuthorizedDomainsListRequest(
        parent=self._FormatApp())
    return self.client.apps_authorizedDomains.List(list_request).domains

  def CreateDomainMapping(self, domain, certificate_id, management_type):
    """Creates a domain mapping for the given application.

    Blocks until the creation operation completes.

    Args:
      domain: str, the custom domain string.
      certificate_id: str, a certificate id for the new domain.
      management_type: SslSettings.SslManagementTypeValueValuesEnum,
        AUTOMATIC or MANUAL certificate provisioning.

    Returns:
      The created DomainMapping object.
    """
    ssl_settings = self.messages.SslSettings(
        certificateId=certificate_id, sslManagementType=management_type)
    mapping = self.messages.DomainMapping(id=domain, sslSettings=ssl_settings)
    create_request = self.messages.AppengineAppsDomainMappingsCreateRequest(
        parent=self._FormatApp(),
        domainMapping=mapping)
    operation = self.client.apps_domainMappings.Create(create_request)
    return operations_util.WaitForOperation(self.client.apps_operations,
                                            operation).response

  def UpdateDomainMapping(self,
                          domain,
                          certificate_id,
                          no_certificate_id,
                          management_type):
    """Updates a domain mapping for the given application.

    Blocks until the update operation completes.

    Args:
      domain: str, the custom domain string.
      certificate_id: str, a certificate id for the domain.
      no_certificate_id: bool, remove the certificate id from the domain.
      management_type: SslSettings.SslManagementTypeValueValuesEnum,
        AUTOMATIC or MANUAL certificate provisioning.

    Returns:
      The updated DomainMapping object.

    Raises:
      exceptions.MinimumArgumentException: if no updatable attribute is given.
    """
    mask_fields = []
    if certificate_id or no_certificate_id:
      mask_fields.append('sslSettings.certificateId')
    if management_type:
      mask_fields.append('sslSettings.sslManagementType')
    ssl_settings = self.messages.SslSettings(
        certificateId=certificate_id, sslManagementType=management_type)
    mapping = self.messages.DomainMapping(id=domain, sslSettings=ssl_settings)
    if not mask_fields:
      raise exceptions.MinimumArgumentException(
          ['--[no-]certificate-id', '--no_managed_certificate'],
          'Please specify at least one attribute to the domain-mapping update.')
    patch_request = self.messages.AppengineAppsDomainMappingsPatchRequest(
        name=self._FormatDomainMapping(domain),
        domainMapping=mapping,
        updateMask=','.join(mask_fields))
    operation = self.client.apps_domainMappings.Patch(patch_request)
    return operations_util.WaitForOperation(self.client.apps_operations,
                                            operation).response

  def _FormatDomainMapping(self, domain):
    """Returns the full resource name of `domain`'s mapping."""
    mapping_ref = self._registry.Parse(
        domain,
        params={'appsId': self.project},
        collection='appengine.apps.domainMappings')
    return mapping_ref.RelativeName()

View File

@@ -0,0 +1,162 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for creating a client to talk to the App Engine Admin API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.app import util
from googlecloudsdk.api_lib.app.api import appengine_api_client_base as base
from googlecloudsdk.calliope import base as calliope_base
# Admin API version to use for each gcloud release track.
VERSION_MAP = {
    calliope_base.ReleaseTrack.GA: 'v1',
    calliope_base.ReleaseTrack.ALPHA: 'v1alpha',
    calliope_base.ReleaseTrack.BETA: 'v1beta'
}
def GetApiClientForTrack(release_track):
  """Returns a firewall API client for the given gcloud release track."""
  return AppengineFirewallApiClient.GetApiClient(VERSION_MAP[release_track])
class AppengineFirewallApiClient(base.AppengineApiClientBase):
  """Client used by gcloud to communicate with the App Engine API."""

  def __init__(self, client):
    super(AppengineFirewallApiClient, self).__init__(client)

  def Create(self, priority, source_range, action, description):
    """Creates a firewall rule for the given application.

    Args:
      priority: int, the priority of the rule between [1, 2^31-1].
        The default rule may not be created, only updated.
      source_range: str, the ip address or range to take action on.
      action: firewall_rules_util.Action, optional action to take on matched
        addresses.
      description: str, an optional string description of the rule.

    Returns:
      The new firewall rule.
    """
    new_rule = self.messages.FirewallRule(
        priority=priority,
        action=action,
        description=description,
        sourceRange=source_range)
    create_request = (
        self.messages.AppengineAppsFirewallIngressRulesCreateRequest(
            parent=self._FormatApp(), firewallRule=new_rule))
    return self.client.apps_firewall_ingressRules.Create(create_request)

  def Delete(self, resource):
    """Deletes a firewall rule for the given application.

    Args:
      resource: str, the resource path to the firewall rule.
    """
    delete_request = (
        self.messages.AppengineAppsFirewallIngressRulesDeleteRequest(
            name=resource.RelativeName()))
    self.client.apps_firewall_ingressRules.Delete(delete_request)

  def List(self, matching_address=None):
    """Lists all ingress firewall rules for the given application.

    Args:
      matching_address: str, an optional ip address to filter matching rules.

    Returns:
      A generator of FirewallRule objects.
    """
    list_request = self.messages.AppengineAppsFirewallIngressRulesListRequest(
        parent=self._FormatApp(), matchingAddress=matching_address)
    return list_pager.YieldFromList(
        self.client.apps_firewall_ingressRules,
        list_request,
        field='ingressRules',
        batch_size_attribute='pageSize')

  def Get(self, resource):
    """Gets a firewall rule for the given application.

    Args:
      resource: str, the resource path to the firewall rule.

    Returns:
      A FirewallRule object.
    """
    get_request = self.messages.AppengineAppsFirewallIngressRulesGetRequest(
        name=resource.RelativeName())
    return self.client.apps_firewall_ingressRules.Get(get_request)

  def Update(self,
             resource,
             priority,
             source_range=None,
             action=None,
             description=None):
    """Updates a firewall rule for the given application.

    Args:
      resource: str, the resource path to the firewall rule.
      priority: int, the priority of the rule.
      source_range: str, optional ip address or range to take action on.
      action: firewall_rules_util.Action, optional action to take on matched
        addresses.
      description: str, optional string description of the rule.

    Returns:
      The updated firewall rule.

    Raises:
      NoFieldsSpecifiedError: when no fields have been specified for the
        update.
    """
    # Preserve the field-mask ordering: action, sourceRange, description.
    mask_fields = [
        field for value, field in ((action, 'action'),
                                   (source_range, 'sourceRange'),
                                   (description, 'description')) if value
    ]
    if not mask_fields:
      raise util.NoFieldsSpecifiedError()
    updated_rule = self.messages.FirewallRule(
        priority=priority,
        action=action,
        description=description,
        sourceRange=source_range)
    patch_request = self.messages.AppengineAppsFirewallIngressRulesPatchRequest(
        name=resource.RelativeName(),
        firewallRule=updated_rule,
        updateMask=','.join(mask_fields))
    return self.client.apps_firewall_ingressRules.Patch(patch_request)

View File

@@ -0,0 +1,191 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for creating a client to talk to the App Engine Admin SSL APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.app.api import appengine_api_client_base as base
from googlecloudsdk.calliope import base as calliope_base
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.core import resources
from googlecloudsdk.core.util import files
# Maps gcloud release tracks to the App Engine Admin API version to use.
SSL_VERSIONS_MAP = {
    calliope_base.ReleaseTrack.GA: 'v1',
    calliope_base.ReleaseTrack.ALPHA: 'v1alpha',
    calliope_base.ReleaseTrack.BETA: 'v1beta'
}
def GetApiClientForTrack(release_track):
  """Retrieves a client based on the release track.

  The API clients override the base class for each track so that methods with
  functional differences can be overridden. The ssl-certificates api does not
  have API changes for alpha, but output is formatted differently, so the alpha
  override simply calls the new API.

  Args:
    release_track: calliope_base.ReleaseTrack, the release track of the command

  Returns:
    A client that calls appengine using the v1beta or v1alpha API.
  """
  return AppengineSslApiClient.GetApiClient(SSL_VERSIONS_MAP[release_track])
class AppengineSslApiClient(base.AppengineApiClientBase):
  """Client used by gcloud to communicate with the App Engine SSL APIs."""

  def __init__(self, client):
    super(AppengineSslApiClient, self).__init__(client)
    self._registry = resources.REGISTRY.Clone()
    # pylint: disable=protected-access
    self._registry.RegisterApiByName('appengine', client._VERSION)

  def CreateSslCertificate(self, display_name, cert_path, private_key_path):
    """Creates a certificate for the given application.

    Args:
      display_name: str, the display name for the new certificate.
      cert_path: str, location on disk to a certificate file.
      private_key_path: str, location on disk to a private key file.

    Returns:
      The created AuthorizedCertificate object.

    Raises:
      Error if the file does not exist or can't be opened/read.
    """
    # Read the certificate first, then the key, so failures surface in the
    # same order as before.
    public_cert = files.ReadFileContents(cert_path)
    private_key = files.ReadFileContents(private_key_path)
    raw_data = self.messages.CertificateRawData(
        privateKey=private_key, publicCertificate=public_cert)
    new_cert = self.messages.AuthorizedCertificate(
        displayName=display_name, certificateRawData=raw_data)
    create_request = (
        self.messages.AppengineAppsAuthorizedCertificatesCreateRequest(
            parent=self._FormatApp(), authorizedCertificate=new_cert))
    return self.client.apps_authorizedCertificates.Create(create_request)

  def DeleteSslCertificate(self, cert_id):
    """Deletes an authorized certificate for the given application.

    Args:
      cert_id: str, the id of the certificate to delete.
    """
    delete_request = (
        self.messages.AppengineAppsAuthorizedCertificatesDeleteRequest(
            name=self._FormatSslCert(cert_id)))
    self.client.apps_authorizedCertificates.Delete(delete_request)

  def GetSslCertificate(self, cert_id):
    """Gets a certificate for the given application.

    Args:
      cert_id: str, the id of the certificate to retrieve.

    Returns:
      The retrieved AuthorizedCertificate object.
    """
    full_view = (self.messages.AppengineAppsAuthorizedCertificatesGetRequest
                 .ViewValueValuesEnum.FULL_CERTIFICATE)
    get_request = self.messages.AppengineAppsAuthorizedCertificatesGetRequest(
        name=self._FormatSslCert(cert_id), view=full_view)
    return self.client.apps_authorizedCertificates.Get(get_request)

  def ListSslCertificates(self):
    """Lists all authorized certificates for the given application.

    Returns:
      A list of AuthorizedCertificate objects.
    """
    list_request = self.messages.AppengineAppsAuthorizedCertificatesListRequest(
        parent=self._FormatApp())
    return self.client.apps_authorizedCertificates.List(list_request).certificates

  def UpdateSslCertificate(self,
                           cert_id,
                           display_name=None,
                           cert_path=None,
                           private_key_path=None):
    """Updates a certificate for the given application.

    One of display_name, cert_path, or private_key_path should be set. Omitted
    fields will not be updated from their current value. Any invalid arguments
    will fail the entire command.

    Args:
      cert_id: str, the id of the certificate to update.
      display_name: str, the display name for a new certificate.
      cert_path: str, location on disk to a certificate file.
      private_key_path: str, location on disk to a private key file.

    Returns:
      The created AuthorizedCertificate object.

    Raises: InvalidInputError if the user does not specify both cert and key.
    """
    # The cert and the key may only be updated as a pair.
    if bool(cert_path) != bool(private_key_path):
      missing_arg = '--private-key' if cert_path else '--certificate'
      raise exceptions.RequiredArgumentException(
          missing_arg,
          'The certificate and the private key must both be updated together.')
    mask_fields = []
    if display_name:
      mask_fields.append('displayName')
    cert_data = None
    if cert_path and private_key_path:
      public_cert = files.ReadFileContents(cert_path)
      private_key = files.ReadFileContents(private_key_path)
      cert_data = self.messages.CertificateRawData(
          privateKey=private_key, publicCertificate=public_cert)
      mask_fields.append('certificateRawData')
    if not mask_fields:
      raise exceptions.MinimumArgumentException([
          '--certificate', '--private-key', '--display-name'
      ], 'Please specify at least one attribute to the certificate update.')
    updated_cert = self.messages.AuthorizedCertificate(
        displayName=display_name, certificateRawData=cert_data)
    patch_request = (
        self.messages.AppengineAppsAuthorizedCertificatesPatchRequest(
            name=self._FormatSslCert(cert_id),
            authorizedCertificate=updated_cert,
            updateMask=','.join(mask_fields)))
    return self.client.apps_authorizedCertificates.Patch(patch_request)

  def _FormatSslCert(self, cert_id):
    """Converts a bare certificate id into its full relative resource name."""
    cert_ref = self._registry.Parse(
        cert_id,
        params={'appsId': self.project},
        collection='appengine.apps.authorizedCertificates')
    return cert_ref.RelativeName()

View File

@@ -0,0 +1,983 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for creating a client to talk to the App Engine Admin API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import copy
import json
import operator
from apitools.base.py import encoding
from apitools.base.py import exceptions as apitools_exceptions
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.app import build as app_cloud_build
from googlecloudsdk.api_lib.app import env
from googlecloudsdk.api_lib.app import exceptions
from googlecloudsdk.api_lib.app import instances_util
from googlecloudsdk.api_lib.app import operations_util
from googlecloudsdk.api_lib.app import region_util
from googlecloudsdk.api_lib.app import service_util
from googlecloudsdk.api_lib.app import util
from googlecloudsdk.api_lib.app import version_util
from googlecloudsdk.api_lib.app.api import appengine_api_client_base
from googlecloudsdk.api_lib.cloudbuild import logs as cloudbuild_logs
from googlecloudsdk.appengine.admin.tools.conversion import convert_yaml
from googlecloudsdk.appengine.api import appinfo
from googlecloudsdk.calliope import base as calliope_base
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
import six
from six.moves import filter # pylint: disable=redefined-builtin
from six.moves import map # pylint: disable=redefined-builtin
# Maps gcloud release tracks to the App Engine Admin API version to use.
APPENGINE_VERSIONS_MAP = {
    calliope_base.ReleaseTrack.GA: 'v1',
    calliope_base.ReleaseTrack.ALPHA: 'v1alpha',
    calliope_base.ReleaseTrack.BETA: 'v1beta'
}
def GetApiClientForTrack(release_track):
  """Returns an App Engine API client for the given gcloud release track."""
  return AppengineApiClient.GetApiClient(APPENGINE_VERSIONS_MAP[release_track])
# Runtime ids treated as first-generation App Engine standard runtimes.
# NOTE(review): usage is not visible in this chunk — confirm against callers.
gen1_runtimes = ['python27']
class AppengineApiClient(appengine_api_client_base.AppengineApiClientBase):
"""Client used by gcloud to communicate with the App Engine API."""
def GetApplication(self):
"""Retrieves the application resource.
Returns:
An app resource representing the project's app.
Raises:
apitools_exceptions.HttpNotFoundError if app doesn't exist
"""
request = self.messages.AppengineAppsGetRequest(name=self._FormatApp())
return self.client.apps.Get(request)
def ListRuntimes(self, environment):
"""Lists the available runtimes for the given App Engine environment.
Args:
environment: The environment for the application, either Standard or
Flexible.
Returns:
v1beta|v1.ListRuntimesResponse, the list of Runtimes.
Raises:
apitools_exceptions.HttpNotFoundError if app doesn't exist
"""
request = self.messages.AppengineAppsListRuntimesRequest(
parent=self._FormatApp(), environment=environment
)
return self.client.apps.ListRuntimes(request)
def IsStopped(self, app):
"""Checks application resource to get serving status.
Args:
app: appengine_v1_messages.Application, the application to check.
Returns:
bool, whether the application is currently disabled. If serving or not
set, returns False.
"""
stopped = app.servingStatus in [
self.messages.Application.ServingStatusValueValuesEnum.USER_DISABLED,
self.messages.Application.ServingStatusValueValuesEnum.SYSTEM_DISABLED]
return stopped
def RepairApplication(self, progress_message=None):
"""Creates missing app resources.
In particular, the Application.code_bucket GCS reference.
Args:
progress_message: str, the message to use while the operation is polled,
if not the default.
Returns:
A long running operation.
"""
request = self.messages.AppengineAppsRepairRequest(
name=self._FormatApp(),
repairApplicationRequest=self.messages.RepairApplicationRequest())
operation = self.client.apps.Repair(request)
log.debug('Received operation: [{operation}]'.format(
operation=operation.name))
return operations_util.WaitForOperation(
self.client.apps_operations, operation, message=progress_message)
def CreateApp(self, location, service_account=None, ssl_policy=None):
"""Creates an App Engine app within the current cloud project.
Creates a new singleton app within the currently selected Cloud Project.
The action is one-time and irreversible.
Args:
location: str, The location (region) of the app, i.e. "us-central"
service_account: str, The app level service account of the app, i.e.
"123@test-app.iam.gserviceaccount.com"
ssl_policy: enum, the app-level SSL policy to update for this App Engine
app. Can be DEFAULT or MODERN.
Raises:
apitools_exceptions.HttpConflictError if app already exists
Returns:
A long running operation.
"""
create_request = None
if service_account:
create_request = self.messages.Application(
id=self.project, locationId=location, serviceAccount=service_account)
else:
create_request = self.messages.Application(
id=self.project, locationId=location)
if ssl_policy:
create_request.sslPolicy = ssl_policy
operation = self.client.apps.Create(create_request)
log.debug('Received operation: [{operation}]'.format(
operation=operation.name))
message = ('Creating App Engine application in project [{project}] and '
'region [{region}].'.format(project=self.project,
region=location))
return operations_util.WaitForOperation(self.client.apps_operations,
operation, message=message)
  def DeployService(self,
                    service_name,
                    version_id,
                    service_config,
                    manifest,
                    build,
                    extra_config_settings=None,
                    service_account_email=None):
    """Updates and deploys new app versions.
    Args:
      service_name: str, The service to deploy.
      version_id: str, The version of the service to deploy.
      service_config: AppInfoExternal, Service info parsed from a service yaml
        file.
      manifest: Dictionary mapping source files to Google Cloud Storage
        locations.
      build: BuildArtifact, a wrapper which contains either the build
        ID for an in-progress parallel build, the name of the container image
        for a serial build, or the options for creating a build elsewhere. Not
        present during standard deploys.
      extra_config_settings: dict, client config settings to pass to the server
        as beta settings.
      service_account_email: Identity of this deployed version. If not set, the
        Admin API will fall back to use the App Engine default appspot service
        account.
    Returns:
      The Admin API Operation, unfinished.
    Raises:
      apitools_exceptions.HttpNotFoundError if build ID doesn't exist
    """
    operation = self._CreateVersion(service_name, version_id, service_config,
                                    manifest, build, extra_config_settings,
                                    service_account_email)
    message = 'Updating service [{service}]'.format(service=service_name)
    if service_config.env in [env.FLEX, env.MANAGED_VMS]:
      message += ' (this may take several minutes)'
    operation_metadata_type = self._ResolveMetadataType()
    # This indicates that a server-side build should be created.
    if build and build.IsBuildOptions():
      if not operation_metadata_type:
        log.warning('Unable to determine build from Operation metadata. '
                    'Skipping log streaming')
      else:
        # Poll the operation until the build is present.
        poller = operations_util.AppEngineOperationBuildPoller(
            self.client.apps_operations, operation_metadata_type)
        operation = operations_util.WaitForOperation(
            self.client.apps_operations, operation, message=message,
            poller=poller)
        build_id = operations_util.GetBuildFromOperation(
            operation, operation_metadata_type)
        if build_id:
          # Switch to streaming the build that the server created for us.
          build = app_cloud_build.BuildArtifact.MakeBuildIdArtifact(build_id)
    # Stream Cloud Build logs while the deployment operation runs.
    if build and build.IsBuildId():
      try:
        build_ref = resources.REGISTRY.Parse(
            build.identifier,
            params={'projectId': properties.VALUES.core.project.GetOrFail},
            collection='cloudbuild.projects.builds')
        cloudbuild_logs.CloudBuildClient().Stream(build_ref, out=log.status)
      except apitools_exceptions.HttpNotFoundError:
        # Not found under the global collection; retry with a regional build
        # resource derived from the app's location.
        region = util.ConvertToCloudRegion(self.GetApplication().locationId)
        build_ref = resources.REGISTRY.Create(
            collection='cloudbuild.projects.locations.builds',
            projectsId=properties.VALUES.core.project.GetOrFail,
            locationsId=region,
            buildsId=build.identifier)
        cloudbuild_logs.CloudBuildClient().Stream(build_ref, out=log.status)
    # Finally, wait for the version-creation operation itself to finish.
    done_poller = operations_util.AppEngineOperationPoller(
        self.client.apps_operations, operation_metadata_type)
    return operations_util.WaitForOperation(
        self.client.apps_operations,
        operation,
        poller=done_poller)
def _ResolveMetadataType(self):
"""Attempts to resolve the expected type for the operation metadata."""
# pylint: disable=protected-access
# TODO(b/74075874): Update ApiVersion method to accurately reflect client.
metadata_type_name = 'OperationMetadata' + self.client._VERSION.title()
# pylint: enable=protected-access
return getattr(self.messages, metadata_type_name)
def _CreateVersion(self,
service_name,
version_id,
service_config,
manifest,
build,
extra_config_settings=None,
service_account_email=None):
"""Begins the updates and deployment of new app versions.
Args:
service_name: str, The service to deploy.
version_id: str, The version of the service to deploy.
service_config: AppInfoExternal, Service info parsed from a service yaml
file.
manifest: Dictionary mapping source files to Google Cloud Storage
locations.
build: BuildArtifact, a wrapper which contains either the build ID for an
in-progress parallel build, the name of the container image for a serial
build, or the options to pass to Appengine for a server-side build.
extra_config_settings: dict, client config settings to pass to the server
as beta settings.
service_account_email: Identity of this deployed version. If not set, the
Admin API will fall back to use the App Engine default appspot service
account.
Returns:
The Admin API Operation, unfinished.
"""
version_resource = self._CreateVersionResource(service_config, manifest,
version_id, build,
extra_config_settings,
service_account_email)
create_request = self.messages.AppengineAppsServicesVersionsCreateRequest(
parent=self._GetServiceRelativeName(service_name=service_name),
version=version_resource)
return self.client.apps_services_versions.Create(create_request)
def GetServiceResource(self, service):
"""Describe the given service.
Args:
service: str, the ID of the service
Returns:
Service resource object from the API
"""
request = self.messages.AppengineAppsServicesGetRequest(
name=self._GetServiceRelativeName(service))
return self.client.apps_services.Get(request)
def SetDefaultVersion(self, service_name, version_id):
"""Sets the default serving version of the given services.
Args:
service_name: str, The service name
version_id: str, The version to set as default.
Returns:
Long running operation.
"""
# Create a traffic split where 100% of traffic goes to the specified
# version.
allocations = {version_id: 1.0}
return self.SetTrafficSplit(service_name, allocations)
def SetTrafficSplit(self, service_name, allocations,
shard_by='UNSPECIFIED', migrate=False):
"""Sets the traffic split of the given services.
Args:
service_name: str, The service name
allocations: A dict mapping version ID to traffic split.
shard_by: A ShardByValuesEnum value specifying how to shard the traffic.
migrate: Whether or not to migrate traffic.
Returns:
Long running operation.
"""
# Create a traffic split where 100% of traffic goes to the specified
# version.
traffic_split = encoding.PyValueToMessage(self.messages.TrafficSplit,
{'allocations': allocations,
'shardBy': shard_by})
update_service_request = self.messages.AppengineAppsServicesPatchRequest(
name=self._GetServiceRelativeName(service_name=service_name),
service=self.messages.Service(split=traffic_split),
migrateTraffic=migrate,
updateMask='split')
message = 'Setting traffic split for service [{service}]'.format(
service=service_name)
operation = self.client.apps_services.Patch(update_service_request)
return operations_util.WaitForOperation(self.client.apps_operations,
operation,
message=message)
def SetIngressTrafficAllowed(self, service_name, ingress_traffic_allowed):
"""Sets the ingress traffic allowed for a service.
Args:
service_name: str, The service name
ingress_traffic_allowed: An IngressTrafficAllowed enum.
Returns:
The completed Operation. The Operation will contain a Service resource.
"""
network_settings = self.messages.NetworkSettings(
ingressTrafficAllowed=ingress_traffic_allowed)
update_service_request = self.messages.AppengineAppsServicesPatchRequest(
name=self._GetServiceRelativeName(service_name=service_name),
service=self.messages.Service(networkSettings=network_settings),
updateMask='networkSettings')
message = 'Setting ingress settings for service [{service}]'.format(
service=service_name)
operation = self.client.apps_services.Patch(update_service_request)
return operations_util.WaitForOperation(
self.client.apps_operations, operation, message=message)
def DeleteVersion(self, service_name, version_id):
"""Deletes the specified version of the given service.
Args:
service_name: str, The service name
version_id: str, The version to delete.
Returns:
The completed Operation.
"""
delete_request = self.messages.AppengineAppsServicesVersionsDeleteRequest(
name=self._FormatVersion(service_name=service_name,
version_id=version_id))
operation = self.client.apps_services_versions.Delete(delete_request)
message = 'Deleting [{0}/{1}]'.format(service_name, version_id)
return operations_util.WaitForOperation(
self.client.apps_operations, operation, message=message)
def SetServingStatus(self, service_name, version_id, serving_status,
block=True):
"""Sets the serving status of the specified version.
Args:
service_name: str, The service name
version_id: str, The version to delete.
serving_status: The serving status to set.
block: bool, whether to block on the completion of the operation
Returns:
The completed Operation if block is True, or the Operation to wait on
otherwise.
"""
patch_request = self.messages.AppengineAppsServicesVersionsPatchRequest(
name=self._FormatVersion(service_name=service_name,
version_id=version_id),
version=self.messages.Version(servingStatus=serving_status),
updateMask='servingStatus')
operation = self.client.apps_services_versions.Patch(patch_request)
if block:
return operations_util.WaitForOperation(self.client.apps_operations,
operation)
else:
return operation
def ListInstances(self, versions):
"""Produces a generator of all instances for the given versions.
Args:
versions: list of version_util.Version
Returns:
A list of instances_util.Instance objects for the given versions
"""
instances = []
for version in versions:
request = self.messages.AppengineAppsServicesVersionsInstancesListRequest(
parent=self._FormatVersion(version.service, version.id))
try:
for instance in list_pager.YieldFromList(
self.client.apps_services_versions_instances,
request,
field='instances',
batch_size=100, # Set batch size so tests can expect it.
batch_size_attribute='pageSize'):
instances.append(
instances_util.Instance.FromInstanceResource(instance))
except apitools_exceptions.HttpNotFoundError:
# Drop versions that were presumed deleted since initial enumeration.
pass
return instances
def GetAllInstances(self, service=None, version=None, version_filter=None):
"""Generator of all instances, optionally filtering by service or version.
Args:
service: str, the ID of the service to filter by.
version: str, the ID of the version to filter by.
version_filter: filter function accepting version_util.Version
Returns:
generator of instance_util.Instance
"""
services = self.ListServices()
log.debug('All services: {0}'.format(services))
services = service_util.GetMatchingServices(
services, [service] if service else None)
versions = self.ListVersions(services)
log.debug('Versions: {0}'.format(list(map(str, versions))))
versions = version_util.GetMatchingVersions(
versions, [version] if version else None, service)
versions = list(filter(version_filter, versions))
return self.ListInstances(versions)
def DebugInstance(self, res, ssh_key=None):
"""Enable debugging of a Flexible instance.
Args:
res: A googleclousdk.core.Resource object.
ssh_key: str, Public SSH key to add to the instance. Examples:
`[USERNAME]:ssh-rsa [KEY_VALUE] [USERNAME]` ,
`[USERNAME]:ssh-rsa [KEY_VALUE] google-ssh {"userName":"[USERNAME]",`
`"expireOn":"[EXPIRE_TIME]"}`
For more information, see Adding and Removing SSH Keys
(https://cloud.google.com/compute/docs/instances/adding-removing-ssh-
keys).
Returns:
The completed Operation.
"""
request = self.messages.AppengineAppsServicesVersionsInstancesDebugRequest(
name=res.RelativeName(),
debugInstanceRequest=self.messages.DebugInstanceRequest(sshKey=ssh_key))
operation = self.client.apps_services_versions_instances.Debug(request)
return operations_util.WaitForOperation(self.client.apps_operations,
operation)
def DeleteInstance(self, res):
"""Delete a Flexible instance.
Args:
res: A googlecloudsdk.core.Resource object.
Returns:
The completed Operation.
"""
request = self.messages.AppengineAppsServicesVersionsInstancesDeleteRequest(
name=res.RelativeName())
operation = self.client.apps_services_versions_instances.Delete(request)
return operations_util.WaitForOperation(self.client.apps_operations,
operation)
def GetInstanceResource(self, res):
"""Describe the given instance of the given version of the given service.
Args:
res: A googlecloudsdk.core.Resource object.
Raises:
apitools_exceptions.HttpNotFoundError: If instance does not
exist.
Returns:
Version resource object from the API
"""
request = self.messages.AppengineAppsServicesVersionsInstancesGetRequest(
name=res.RelativeName())
return self.client.apps_services_versions_instances.Get(request)
def StopVersion(self, service_name, version_id, block=True):
"""Stops the specified version.
Args:
service_name: str, The service name
version_id: str, The version to stop.
block: bool, whether to block on the completion of the operation
Returns:
The completed Operation if block is True, or the Operation to wait on
otherwise.
"""
return self.SetServingStatus(
service_name,
version_id,
self.messages.Version.ServingStatusValueValuesEnum.STOPPED,
block)
def StartVersion(self, service_name, version_id, block=True):
"""Starts the specified version.
Args:
service_name: str, The service name
version_id: str, The version to start.
block: bool, whether to block on the completion of the operation
Returns:
The completed Operation if block is True, or the Operation to wait on
otherwise.
"""
return self.SetServingStatus(
service_name,
version_id,
self.messages.Version.ServingStatusValueValuesEnum.SERVING,
block)
def ListServices(self):
"""Lists all services for the given application.
Returns:
A list of service_util.Service objects.
"""
request = self.messages.AppengineAppsServicesListRequest(
parent=self._FormatApp())
services = []
for service in list_pager.YieldFromList(
self.client.apps_services, request, field='services',
batch_size=100, batch_size_attribute='pageSize'):
traffic_split = {}
if service.split:
for split in service.split.allocations.additionalProperties:
traffic_split[split.key] = split.value
services.append(
service_util.Service(self.project, service.id, traffic_split))
return services
def GetVersionResource(self, service, version):
"""Describe the given version of the given service.
Args:
service: str, the ID of the service for the version to describe.
version: str, the ID of the version to describe.
Returns:
Version resource object from the API.
"""
request = self.messages.AppengineAppsServicesVersionsGetRequest(
name=self._FormatVersion(service, version),
view=(self.messages.
AppengineAppsServicesVersionsGetRequest.ViewValueValuesEnum.FULL))
return self.client.apps_services_versions.Get(request)
def ListVersions(self, services):
"""Lists all versions for the specified services.
Args:
services: A list of service_util.Service objects.
Returns:
A list of version_util.Version objects.
"""
versions = []
for service in services:
# Get the versions.
request = self.messages.AppengineAppsServicesVersionsListRequest(
parent=self._GetServiceRelativeName(service.id))
try:
for version in list_pager.YieldFromList(
self.client.apps_services_versions,
request,
field='versions',
batch_size=100,
batch_size_attribute='pageSize'):
versions.append(
version_util.Version.FromVersionResource(version, service))
except apitools_exceptions.HttpNotFoundError:
# Drop services that were presumed deleted since initial enumeration.
pass
return versions
def ListRegions(self):
"""List all regions for the project, and support for standard and flexible.
Returns:
List of region_util.Region instances for the project.
"""
request = self.messages.AppengineAppsLocationsListRequest(
name='apps/{0}'.format(self.project))
regions = list_pager.YieldFromList(
self.client.apps_locations, request, field='locations',
batch_size=100, batch_size_attribute='pageSize')
return [region_util.Region.FromRegionResource(loc) for loc in regions]
def DeleteService(self, service_name):
"""Deletes the specified service.
Args:
service_name: str, Name of the service to delete.
Returns:
The completed Operation.
"""
delete_request = self.messages.AppengineAppsServicesDeleteRequest(
name=self._GetServiceRelativeName(service_name=service_name))
operation = self.client.apps_services.Delete(delete_request)
message = 'Deleting [{}]'.format(service_name)
return operations_util.WaitForOperation(self.client.apps_operations,
operation,
message=message)
def GetOperation(self, op_id):
"""Grabs details about a particular gcloud operation.
Args:
op_id: str, ID of operation.
Returns:
Operation resource object from API call.
"""
request = self.messages.AppengineAppsOperationsGetRequest(
name=self._FormatOperation(op_id))
return self.client.apps_operations.Get(request)
def ListOperations(self, op_filter=None):
"""Lists all operations for the given application.
Args:
op_filter: String to filter which operations to grab.
Returns:
A list of opeartion_util.Operation objects.
"""
request = self.messages.AppengineAppsOperationsListRequest(
name=self._FormatApp(),
filter=op_filter)
operations = list_pager.YieldFromList(
self.client.apps_operations, request, field='operations',
batch_size=100, batch_size_attribute='pageSize')
return [operations_util.Operation(op) for op in operations]
  def _CreateVersionResource(self,
                             service_config,
                             manifest,
                             version_id,
                             build,
                             extra_config_settings=None,
                             service_account_email=None):
    """Constructs a Version resource for deployment.

    Args:
      service_config: ServiceYamlInfo, Service info parsed from a service yaml
        file.
      manifest: Dictionary mapping source files to Google Cloud Storage
        locations.
      version_id: str, The version of the service.
      build: BuildArtifact, The build ID, image path, or build options.
      extra_config_settings: dict, client config settings to pass to the server
        as beta settings.
      service_account_email: identity of this deployed version. If not set,
        Admin API will fallback to use the App Engine default appspot SA.

    Returns:
      A Version resource whose Deployment includes either a container pointing
      to a completed image, or a build pointing to an in-progress build.

    Raises:
      exceptions.ConfigError: if the yaml sets both `app_engine_apis` and
        `app_engine_bundled_services`, or cannot be converted to the App
        Engine configuration format.
      appinfo.validation.ValidationError: if a bundled service name has no
        matching enum value in the API messages.
    """
    config_dict = copy.deepcopy(service_config.parsed.ToDict())
    # We always want to set a value for entrypoint when sending the request
    # to Zeus, even if one wasn't specified in the yaml file
    if 'entrypoint' not in config_dict:
      config_dict['entrypoint'] = ''
    # These two yaml keys are mutually exclusive ways of requesting bundled
    # services; reject configs that specify both.
    if (
        'app_engine_apis' in config_dict
        and 'app_engine_bundled_services' in config_dict
    ):
      raise exceptions.ConfigError(
          'Cannot specify both `app_engine_apis` and '
          '`app_engine_bundled_services` in the same `app.yaml` file.'
      )
    try:
      # pylint: disable=protected-access
      schema_parser = convert_yaml.GetSchemaParser(self.client._VERSION)
      json_version_resource = schema_parser.ConvertValue(config_dict)
    except ValueError as e:
      raise exceptions.ConfigError(
          '[{f}] could not be converted to the App Engine configuration '
          'format for the following reason: {msg}'.format(
              f=service_config.file, msg=six.text_type(e)))
    log.debug('Converted YAML to JSON: "{0}"'.format(
        json.dumps(json_version_resource, indent=2, sort_keys=True)))
    # Override the 'service_account' in app.yaml if CLI provided this param.
    if service_account_email is not None:
      json_version_resource['serviceAccount'] = service_account_email
    json_version_resource['deployment'] = {}
    # Add the deployment manifest information.
    json_version_resource['deployment']['files'] = manifest
    # Exactly one of the three deployment styles is attached, depending on
    # what kind of artifact the build is.
    if build:
      if build.IsImage():
        json_version_resource['deployment']['container'] = {
            'image': build.identifier
        }
      elif build.IsBuildId():
        json_version_resource['deployment']['build'] = {
            'cloudBuildId': build.identifier
        }
      elif build.IsBuildOptions():
        json_version_resource['deployment']['cloudBuildOptions'] = (
            build.identifier)
    version_resource = encoding.PyValueToMessage(self.messages.Version,
                                                 json_version_resource)
    # For consistency in the tests:
    if version_resource.envVariables:
      version_resource.envVariables.additionalProperties.sort(
          key=lambda x: x.key)
    # We need to pipe some settings to the server as beta settings.
    # NOTE: this mutates the JSON dict *after* version_resource was built from
    # it; the message is re-synchronized by the betaSettings block below.
    if extra_config_settings:
      if 'betaSettings' not in json_version_resource:
        json_version_resource['betaSettings'] = {}
      json_version_resource['betaSettings'].update(extra_config_settings)
    # In the JSON representation, BetaSettings are a dict of key-value pairs.
    # In the Message representation, BetaSettings are an ordered array of
    # key-value pairs. Sort the key-value pairs here, so that unit testing is
    # possible.
    if 'betaSettings' in json_version_resource:
      json_dict = json_version_resource.get('betaSettings')
      attributes = []
      for key, value in sorted(json_dict.items()):
        attributes.append(
            self.messages.Version.BetaSettingsValue.AdditionalProperty(
                key=key, value=value))
      version_resource.betaSettings = self.messages.Version.BetaSettingsValue(
          additionalProperties=attributes)
    # Add the app engine bundled services to the version resource.
    if 'appEngineBundledServices' in json_version_resource:
      bundled_services_enums = []
      # Sorted so the enum list order is deterministic for tests.
      for service_name in sorted(
          json_version_resource['appEngineBundledServices']
      ):
        enum_value = service_name.upper()
        log.debug('enum_value: %s', enum_value)
        try:
          bundled_services_enums.append(
              getattr(
                  self.messages.Version.AppEngineBundledServicesValueListEntryValuesEnum,
                  enum_value,
              )
          )
        except AttributeError:
          # The upper-cased name has no matching enum member.
          raise appinfo.validation.ValidationError(
              f'Invalid bundled service: {service_name}.'
          )
      if bundled_services_enums:
        log.debug(
            'Bundled services enums: %s', bundled_services_enums
        )
        version_resource.appEngineBundledServices = bundled_services_enums
        log.debug(
            'version_resource.appEngineBundledServices: %s',
            version_resource.appEngineBundledServices,
        )
    # The files in the deployment manifest also need to be sorted for unit
    # testing purposes.
    try:
      version_resource.deployment.files.additionalProperties.sort(
          key=operator.attrgetter('key')
      )
    except AttributeError:  # manifest not present, or no files in manifest
      pass
    # Add an ID for the version which is to be created.
    version_resource.id = version_id
    return version_resource
def UpdateDispatchRules(self, dispatch_rules):
"""Updates an application's dispatch rules.
Args:
dispatch_rules: [{'service': str, 'domain': str, 'path': str}], dispatch-
rules to set-and-replace.
Returns:
Long running operation.
"""
# Create a configuration update request.
update_mask = 'dispatchRules,'
application_update = self.messages.Application()
application_update.dispatchRules = [self.messages.UrlDispatchRule(**r)
for r in dispatch_rules]
update_request = self.messages.AppengineAppsPatchRequest(
name=self._FormatApp(),
application=application_update,
updateMask=update_mask)
operation = self.client.apps.Patch(update_request)
log.debug('Received operation: [{operation}] with mask [{mask}]'.format(
operation=operation.name,
mask=update_mask))
return operations_util.WaitForOperation(self.client.apps_operations,
operation)
def UpdateDatabaseType(self, database_type):
"""Updates an application's database_type.
Args:
database_type: New database type to switch to
Returns:
Long running operation.
"""
# Create a configuration update request.
update_mask = 'databaseType'
application_update = self.messages.Application()
application_update.databaseType = database_type
update_request = self.messages.AppengineAppsPatchRequest(
name=self._FormatApp(),
application=application_update,
updateMask=update_mask)
operation = self.client.apps.Patch(update_request)
log.debug('Received operation: [{operation}] with mask [{mask}]'.format(
operation=operation.name, mask=update_mask))
return operations_util.WaitForOperation(self.client.apps_operations,
operation)
def CheckGen1AppId(self, service_name, project_id):
"""Checks if the service contains a Gen1 app.
Args:
service_name: str, The service name
project_id: str, The project id
Returns:
boolean, True if the service contains a Gen1 app, False otherwise
"""
request = self.messages.AppengineAppsServicesMigrationCheckGen1appIdRequest(
name=self._GetServiceRelativeName(service_name),
checkGen1AppIdRequest=self.messages.CheckGen1AppIdRequest(
projectId=project_id
),
)
return self.client.apps_services_migration.CheckGen1appId(request)
def MigrateConfigYaml(
self, project_id, config_as_string, runtime, service_name
):
"""Migrates the app.yaml file provided by the user to be Gen2 compatible.
Args:
project_id: str, The project id
config_as_string: str, The config as a string
runtime: str, The runtime
service_name: str, The service name
Returns:
str, The migrated config as a string
"""
if runtime in gen1_runtimes:
runtime_enum = (
self.messages.MigrateConfigYamlRequest.RuntimeValueValuesEnum.GEN1_PYTHON27
)
else:
runtime_enum = (
self.messages.MigrateConfigYamlRequest.RuntimeValueValuesEnum.MIGRATION_ASSIST_RUNTIME_UNSPECIFIED
)
req = self.messages.AppengineAppsServicesMigrationMigrateConfigYamlRequest(
name=self._GetServiceRelativeName(service_name),
migrateConfigYamlRequest=self.messages.MigrateConfigYamlRequest(
projectId=project_id,
configAsString=config_as_string,
runtime=runtime_enum,
),
)
return self.client.apps_services_migration.MigrateConfigYaml(req)
def MigrateCodeFile(self, project_id, code_as_string, runtime, service_name):
"""Migrates the code file provided by the user to Gen2 runtime.
Args:
project_id: str, The project id
code_as_string: str, The code as a string
runtime: str, The runtime
service_name: str, The service name
Returns:
Long running operation
"""
if runtime in gen1_runtimes:
runtime_enum = (
self.messages.MigrateCodeFileRequest.RuntimeValueValuesEnum.GEN1_PYTHON27
)
else:
runtime_enum = (
self.messages.MigrateCodeFileRequest.RuntimeValueValuesEnum.MIGRATION_ASSIST_RUNTIME_UNSPECIFIED
)
request = (
self.messages.AppengineAppsServicesMigrationMigrateCodeFileRequest(
name=self._GetServiceRelativeName(service_name),
migrateCodeFileRequest=self.messages.MigrateCodeFileRequest(
projectId=project_id,
codeAsString=code_as_string,
runtime=runtime_enum,
),
)
)
operation = self.client.apps_services_migration.MigrateCodeFile(request)
return operations_util.WaitForOperation(
self.client.apps_operations, operation
)

View File

@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*- #
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility classes for interacting with the Cloud Build API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import enum
from googlecloudsdk.api_lib.cloudbuild import build
class BuildArtifact(object):
  """Represents a build of a flex container, either in-progress or completed.

  A build artifact is either a build_id for an in-progress build, the image
  name for a completed container build, or options for the build to be created
  elsewhere.
  """

  class BuildType(enum.Enum):
    """Discriminates which kind of identifier this artifact carries."""
    IMAGE = 1
    BUILD_ID = 2
    BUILD_OPTIONS = 3

  def __init__(self, build_type, identifier, build_op=None):
    # build_type: BuildType, what `identifier` means.
    # identifier: str or dict, image name, build id, or build options.
    # build_op: optional operation message the artifact was derived from.
    self.build_type = build_type
    self.identifier = identifier
    self.build_op = build_op

  def IsImage(self):
    """Returns True if this artifact names a completed container image."""
    return self.build_type == self.BuildType.IMAGE

  def IsBuildId(self):
    """Returns True if this artifact is an in-progress build's ID."""
    return self.build_type == self.BuildType.BUILD_ID

  def IsBuildOptions(self):
    """Returns True if this artifact holds options for a build elsewhere."""
    return self.build_type == self.BuildType.BUILD_OPTIONS

  @classmethod
  def MakeBuildIdArtifact(cls, build_id):
    """Creates a BUILD_ID artifact from a bare build ID string."""
    return cls(cls.BuildType.BUILD_ID, build_id)

  @classmethod
  def MakeImageArtifact(cls, image_name):
    """Creates an IMAGE artifact from a container image name."""
    return cls(cls.BuildType.IMAGE, image_name)

  @classmethod
  def MakeBuildOptionsArtifact(cls, build_options):
    """Creates a BUILD_OPTIONS artifact from build options."""
    return cls(cls.BuildType.BUILD_OPTIONS, build_options)

  @classmethod
  def MakeBuildIdArtifactFromOp(cls, build_op):
    """Creates a BUILD_ID artifact from an in-progress build operation."""
    build_id = build.GetBuildProp(build_op, 'id', required=True)
    return cls(cls.BuildType.BUILD_ID, build_id, build_op)

  @classmethod
  def MakeImageArtifactFromOp(cls, build_op):
    """Create Image BuildArtifact from build operation.

    Raises:
      build.BuildFailedError: if the operation's source does not identify an
        image name.
    """
    # BUG FIX: image_name must be initialized; previously a source with no
    # 'storageSource'/'object' property raised UnboundLocalError instead of
    # the intended BuildFailedError.
    image_name = None
    source = build.GetBuildProp(build_op, 'source')
    for prop in source.object_value.properties:
      if prop.key == 'storageSource':
        for storage_prop in prop.value.object_value.properties:
          if storage_prop.key == 'object':
            image_name = storage_prop.value.string_value
    if image_name is None:
      raise build.BuildFailedError('Could not determine image name')
    return cls(cls.BuildType.IMAGE, image_name, build_op)

View File

@@ -0,0 +1,260 @@
# -*- coding: utf-8 -*- #
# Copyright 2013 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility methods to upload source to GCS and call Cloud Build service."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import gzip
import io
import operator
import os
import tarfile
from apitools.base.py import encoding
from googlecloudsdk.api_lib.cloudbuild import cloudbuild_util
from googlecloudsdk.api_lib.storage import storage_api
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import files
from googlecloudsdk.core.util import times
import six
from six.moves import filter # pylint: disable=redefined-builtin
# Paths that shouldn't be ignored client-side.
# Behavioral parity with github.com/docker/docker-py.
BLOCKLISTED_DOCKERIGNORE_PATHS = ['Dockerfile', '.dockerignore']
def _CreateTar(upload_dir, gen_files, paths, gz):
  """Create tarfile for upload to GCS.

  The third-party code closes the tarfile after creating, which does not
  allow us to write generated files after calling docker.utils.tar
  since gzipped tarfiles can't be opened in append mode.

  Args:
    upload_dir: the directory to be archived
    gen_files: dict of filename to (str) contents of generated files to add
      to the tar in addition to the on-disk paths
    paths: allowed paths in the tarfile, relative to upload_dir
    gz: gzipped tarfile object
  """
  root = os.path.abspath(upload_dir)
  t = tarfile.open(mode='w', fileobj=gz)
  try:
    # Archive on-disk files in sorted order so the output is deterministic.
    for path in sorted(paths):
      full_path = os.path.join(root, path)
      t.add(full_path, arcname=path, recursive=False)
    # Append generated (in-memory) files.
    # Fixed: use dict.items() instead of six.iteritems — equivalent on
    # Python 3 and removes the dead py2 compatibility shim.
    for name, contents in gen_files.items():
      data = contents.encode()
      tar_info = tarfile.TarInfo(name=name)
      tar_info.size = len(data)
      t.addfile(tar_info, fileobj=io.BytesIO(data))
  finally:
    # Close even on error so partial archives are flushed/released.
    t.close()
def _GetDockerignoreExclusions(upload_dir, gen_files):
  """Helper function to read the .dockerignore on disk or in generated files.

  Args:
    upload_dir: the path to the root directory.
    gen_files: dict of filename to contents of generated files.

  Returns:
    Set of exclusion expressions from the dockerignore file.
  """
  dockerignore_path = os.path.join(upload_dir, '.dockerignore')
  # An on-disk .dockerignore takes precedence over a generated one.
  if os.path.exists(dockerignore_path):
    ignore_contents = files.ReadFileContents(dockerignore_path)
  else:
    ignore_contents = gen_files.get('.dockerignore')
  if not ignore_contents:
    return set()
  # Keep non-blank lines, then drop paths that must never be excluded
  # client-side (parity with docker's own behavior).
  exclusions = {line for line in ignore_contents.splitlines() if line}
  return exclusions - set(BLOCKLISTED_DOCKERIGNORE_PATHS)
def _GetIncludedPaths(upload_dir, source_files, exclude):
  """Helper function to filter paths in root using dockerignore and skip_files.

  We iterate separately to filter on skip_files in order to preserve expected
  behavior (standard deployment skips directories if they contain only files
  ignored by skip_files).

  Args:
    upload_dir: the path to the root directory.
    source_files: [str], relative paths to upload.
    exclude: the .dockerignore file exclusions.

  Returns:
    Set of paths (relative to upload_dir) to include.
  """
  # Deferred import keeps gcloud startup fast when docker isn't needed.
  # pylint: disable=g-import-not-at-top
  import docker

  # Mirror how docker.utils.tar() resolves the root and excluded paths.
  root = os.path.abspath(upload_dir)
  candidate_paths = docker.utils.exclude_paths(root, list(exclude))
  # Keep only paths also allowed by .gcloudignore / skip_files.
  candidate_paths.intersection_update(source_files)
  return candidate_paths
def UploadSource(upload_dir, source_files, object_ref, gen_files=None):
  """Upload a gzipped tarball of the source directory to GCS.

  Note: To provide parity with docker's behavior, we must respect .dockerignore.

  Args:
    upload_dir: the directory to be archived.
    source_files: [str], relative paths to upload.
    object_ref: storage_util.ObjectReference, the Cloud Storage location to
      upload the source tarball to.
    gen_files: dict of filename to (str) contents of generated config and
      source context files.
  """
  gen_files = gen_files or {}
  exclusions = _GetDockerignoreExclusions(upload_dir, gen_files)
  included_paths = _GetIncludedPaths(upload_dir, source_files, exclusions)
  # tempfile.NamedTemporaryFile is avoided deliberately: on Windows, cleanup
  # races at process termination leave artifacts behind
  # (https://bugs.python.org/issue14243). A TemporaryDirectory plus an
  # explicit writer sidesteps most of that; CTRL-C on Windows (or kill -9 on
  # Posix) may still leak both the directory and the file.
  with files.TemporaryDirectory() as temp_dir:
    f = files.BinaryFileWriter(os.path.join(temp_dir, 'src.tgz'))
    with gzip.GzipFile(mode='wb', fileobj=f) as gz:
      _CreateTar(upload_dir, gen_files, included_paths, gz)
    f.close()
    storage_api.StorageClient().CopyFileToGCS(f.name, object_ref)
def GetServiceTimeoutSeconds(timeout_property_str):
  """Returns the service timeout in seconds given the duration string."""
  if timeout_property_str is None:
    return None
  # Bare numbers are interpreted as seconds.
  parsed_duration = times.ParseDuration(timeout_property_str,
                                        default_suffix='s')
  return int(parsed_duration.total_seconds)
def GetServiceTimeoutString(timeout_property_str):
  """Returns the service timeout duration string with suffix appended."""
  if timeout_property_str is None:
    return None
  seconds = GetServiceTimeoutSeconds(timeout_property_str)
  return '{0}s'.format(seconds)
class InvalidBuildError(ValueError):
  """Error indicating that ExecuteCloudBuild was given a bad Build message."""

  def __init__(self, field):
    message = ('Field [{}] was provided, but should not have been. '
               'You may be using an improper Cloud Build pipeline.').format(
                   field)
    super(InvalidBuildError, self).__init__(message)
def _ValidateBuildFields(build, fields):
  """Validates that a Build message doesn't have fields that we populate."""
  for field_name in fields:
    # Treat a missing attribute the same as an unset one.
    if getattr(build, field_name, None) is not None:
      raise InvalidBuildError(field_name)
def GetDefaultBuild(output_image):
  """Get the default build for this runtime.

  This build just uses the latest docker builder image (location pulled from the
  app/container_builder_image property) to run a `docker build` with the given
  tag.

  Args:
    output_image: GCR location for the output docker image (e.g.
      `gcr.io/test-gae/hardcoded-output-tag`)

  Returns:
    Build, a CloudBuild Build message with the given steps (ready to be given to
    FixUpBuild).
  """
  messages = cloudbuild_util.GetMessagesModule()
  builder = properties.VALUES.app.container_builder_image.Get()
  log.debug('Using builder image: [{0}]'.format(builder))
  docker_build_step = messages.BuildStep(
      name=builder, args=['build', '-t', output_image, '.'])
  return messages.Build(steps=[docker_build_step], images=[output_image])
def FixUpBuild(build, object_ref):
  """Return a modified Build object with run-time values populated.

  Specifically:
  - `source` is pulled from the given object_ref
  - `timeout` comes from the app/cloud_build_timeout property
  - `logsBucket` uses the bucket from object_ref

  Args:
    build: cloudbuild Build message. The Build to modify. Fields 'timeout',
      'source', and 'logsBucket' will be added and may not be given.
    object_ref: storage_util.ObjectReference, the Cloud Storage location of the
      source tarball.

  Returns:
    Build, (copy) of the given Build message with the specified fields
    populated.

  Raises:
    InvalidBuildError: if the Build message had one of the fields this function
      sets pre-populated
  """
  messages = cloudbuild_util.GetMessagesModule()
  # Work on a copy so the caller's message is left untouched.
  build = encoding.CopyProtoMessage(build)
  # CopyProtoMessage doesn't preserve the order of additionalProperties; sort
  # them for a consistent order (this *only* matters for tests).
  if build.substitutions:
    build.substitutions.additionalProperties.sort(
        key=operator.attrgetter('key'))
  # Refuse to clobber caller-supplied values for the fields we populate.
  _ValidateBuildFields(build, ('source', 'timeout', 'logsBucket'))
  build.timeout = GetServiceTimeoutString(
      properties.VALUES.app.cloud_build_timeout.Get())
  build.logsBucket = object_ref.bucket
  storage_source = messages.StorageSource(
      bucket=object_ref.bucket,
      object=object_ref.name,
  )
  build.source = messages.Source(storageSource=storage_source)
  return build

View File

@@ -0,0 +1,318 @@
# -*- coding: utf-8 -*- #
# Copyright 2015 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility methods used by the deploy_app command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import datetime
import hashlib
import os
from apitools.base.py import exceptions as apitools_exceptions
from googlecloudsdk.api_lib.app import metric_names
from googlecloudsdk.api_lib.storage import storage_api
from googlecloudsdk.api_lib.storage import storage_util
from googlecloudsdk.appengine.tools import context_util
from googlecloudsdk.command_lib.storage import storage_parallel
from googlecloudsdk.core import exceptions as core_exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import metrics
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import encoding
from googlecloudsdk.core.util import files as file_utils
from googlecloudsdk.core.util import times
from six.moves import map # pylint: disable=redefined-builtin
_DEFAULT_NUM_THREADS = 8
# TTL expiry margin, to compensate for incorrect local time and timezone,
# as well as deployment time.
_TTL_MARGIN = datetime.timedelta(1)
class LargeFileError(core_exceptions.Error):
  """Raised when a single file is too large to upload to App Engine."""

  def __init__(self, path, size, max_size):
    message = (
        'Cannot upload file [{path}], which has size [{size}] (greater than '
        'maximum allowed size of [{max_size}]). Please delete the file or '
        'add to the skip_files entry in your application .yaml file and try '
        'again.'
    ).format(path=path, size=size, max_size=max_size)
    super(LargeFileError, self).__init__(message)
class MultiError(core_exceptions.Error):
  """Aggregates one or more errors raised during a batched operation."""

  def __init__(self, operation_description, errors):
    """Initializes the error with a summary line plus each underlying error.

    Args:
      operation_description: str, phrase spliced into the summary line
        describing when the errors occurred (e.g. 'during file upload').
      errors: list of Exception, the underlying errors; retained on
        self.errors for callers that want to inspect them.
    """
    if len(errors) > 1:
      msg = 'Multiple errors occurred {0}\n'.format(operation_description)
    else:
      msg = 'An error occurred {0}\n'.format(operation_description)
    errors_string = '\n\n'.join(map(str, errors))
    # BUG FIX: previously called super(core_exceptions.Error, self), which
    # skips Error.__init__ in the MRO entirely; super() must be given the
    # class being defined.
    super(MultiError, self).__init__(msg + errors_string)
    self.errors = errors
def _BuildDeploymentManifest(upload_dir, source_files, bucket_ref, tmp_dir):
  """Builds a deployment manifest for use with the App Engine Admin API.

  Args:
    upload_dir: str, path to the service's upload directory
    source_files: [str], relative paths to upload.
    bucket_ref: The reference to the bucket files will be placed in.
    tmp_dir: A temp directory for storing generated files (currently just
      source context files).

  Returns:
    A deployment manifest (dict) for use with the Admin API.
  """
  bucket_url = 'https://storage.googleapis.com/{0}'.format(bucket_ref.bucket)
  manifest = {}

  def _AddEntry(rel_path, full_path):
    # Entries are keyed by the slash-normalized relative path and point at
    # the content-addressed GCS object for the file's SHA-1.
    sha1_hash = file_utils.Checksum.HashSingleFile(full_path,
                                                   algorithm=hashlib.sha1)
    manifest[_FormatForManifest(rel_path)] = {
        'sourceUrl': '/'.join([bucket_url, sha1_hash]),
        'sha1Sum': sha1_hash,
    }

  # Normal application files.
  for rel_path in source_files:
    _AddEntry(rel_path, os.path.join(upload_dir, rel_path))

  # Source context files: temporary files which indicate the current state of
  # the source repository (git, cloud repo, etc.).
  context_files = context_util.CreateContextFiles(
      tmp_dir, None, source_dir=upload_dir)
  for context_file in context_files:
    rel_path = os.path.basename(context_file)
    if rel_path in manifest:
      # The source context file was explicitly provided by the user.
      log.debug('Source context already exists. Using the existing file.')
    else:
      _AddEntry(rel_path, context_file)
  return manifest
def _GetLifecycleDeletePolicy(storage_client, bucket_ref):
  """Get the TTL of objects in days as specified by the lifecycle policy.

  Only "delete by age" policies are accounted for.

  Args:
    storage_client: storage_api.StorageClient, API client wrapper.
    bucket_ref: The GCS bucket reference.

  Returns:
    datetime.timedelta, TTL of objects in days, or None if no deletion
    policy on the bucket.
  """
  try:
    bucket = storage_client.client.buckets.Get(
        request=storage_client.messages.StorageBucketsGetRequest(
            bucket=bucket_ref.bucket),
        global_params=storage_client.messages.StandardQueryParameters(
            fields='lifecycle'))
  except apitools_exceptions.HttpForbiddenError:
    # No metadata access: behave as if the bucket has no TTL.
    return None
  if not bucket.lifecycle:
    return None
  delete_ages = [
      rule.condition.age
      for rule in bucket.lifecycle.rule
      if rule.action.type == 'Delete' and rule.condition.age is not None and
      rule.condition.age >= 0
  ]
  if not delete_ages:
    return None
  # The strictest (smallest) age wins.
  return datetime.timedelta(min(delete_ages))
def _IsTTLSafe(ttl, obj):
  """Determines whether a GCS object is close to end-of-life.

  In order to reduce false negative rate (objects that are close to deletion
  but aren't marked as such) the returned filter is forward-adjusted with
  _TTL_MARGIN.

  Args:
    ttl: datetime.timedelta, TTL of objects, or None if no TTL.
    obj: storage object to check.

  Returns:
    True if the object is safe or False if it is approaching end of life.
  """
  if ttl is None:
    return True
  # Treat the object as expiring _TTL_MARGIN earlier than its real TTL.
  object_age = times.Now(times.UTC) - obj.timeCreated
  return object_age <= (ttl - _TTL_MARGIN)
def _BuildFileUploadMap(manifest, source_dir, bucket_ref, tmp_dir,
                        max_file_size):
  """Builds a map of files to upload, indexed by their hash.

  This skips already-uploaded files.

  Args:
    manifest: A dict containing the deployment manifest for a single service.
    source_dir: The relative source directory of the service.
    bucket_ref: The GCS bucket reference to upload files into.
    tmp_dir: The path to a temporary directory where generated files may be
      stored. If a file in the manifest is not found in the source directory,
      it will be retrieved from this directory instead.
    max_file_size: int, File size limit per individual file or None if no
      limit.

  Raises:
    LargeFileError: if one of the files to upload exceeds the maximum App
      Engine file size.

  Returns:
    A dict mapping hashes to file paths that should be uploaded.
  """
  storage_client = storage_api.StorageClient()
  ttl = _GetLifecycleDeletePolicy(storage_client, bucket_ref)
  # Hashes already in the bucket (and not close to lifecycle deletion) do not
  # need re-uploading.
  existing_items = set(o.name
                       for o in storage_client.ListBucket(bucket_ref)
                       if _IsTTLSafe(ttl, o))
  files_to_upload = {}
  skipped_size = 0
  total_size = 0
  for rel_path in manifest:
    full_path = os.path.join(source_dir, rel_path)
    # Generated files live under tmp_dir rather than source_dir; fall back
    # there when the manifest entry isn't on disk in the source tree.
    if not os.path.exists(encoding.Encode(full_path, encoding='utf-8')):
      full_path = os.path.join(tmp_dir, rel_path)
    # Check the size limit while building the map so that too-large files are
    # caught even when their content has already been uploaded.
    size = os.path.getsize(encoding.Encode(full_path, encoding='utf-8'))
    if max_file_size and size > max_file_size:
      raise LargeFileError(full_path, size, max_file_size)
    total_size += size
    sha1_hash = manifest[rel_path]['sha1Sum']
    if sha1_hash in existing_items:
      log.debug('Skipping upload of [{f}]'.format(f=rel_path))
      skipped_size += size
    else:
      files_to_upload[sha1_hash] = full_path
  if total_size:
    log.info('Incremental upload skipped {pct}% of data'.format(
        pct=round(100.0 * skipped_size / total_size, 2)))
  return files_to_upload
class FileUploadTask(object):
  """Value object pairing a file's SHA-1 with its local path and bucket URL."""

  def __init__(self, sha1_hash, path, bucket_url):
    # Content hash used as the GCS object name.
    self.sha1_hash = sha1_hash
    # Local filesystem path of the file to upload.
    self.path = path
    # Destination bucket URL.
    self.bucket_url = bucket_url
def _UploadFilesThreads(files_to_upload, bucket_ref):
  """Uploads files to App Engine Cloud Storage bucket using threads.

  Args:
    files_to_upload: dict {str: str}, map of checksum to local path
    bucket_ref: storage_api.BucketReference, the reference to the bucket files
      will be placed in.

  Raises:
    MultiError: if one or more errors occurred during file upload.
  """
  num_threads = (properties.VALUES.app.num_file_upload_threads.GetInt() or
                 storage_parallel.DEFAULT_NUM_THREADS)
  # Sort by hash: the test framework requires a deterministic order of
  # mocked API calls.
  tasks = [
      storage_parallel.FileUploadTask(
          path,
          storage_util.ObjectReference.FromBucketRef(bucket_ref, sha1_hash))
      for sha1_hash, path in sorted(files_to_upload.items())
  ]
  storage_parallel.UploadFiles(
      tasks, num_threads=num_threads, show_progress_bar=True)
def CopyFilesToCodeBucket(upload_dir, source_files,
                          bucket_ref, max_file_size=None):
  """Copies application files to the Google Cloud Storage code bucket.

  Uses the Cloud Storage API with threads.

  Consider the following original structure:
    app/
      main.py
      tools/
        foo.py

  If main.py has SHA1 hash 123 and foo.py has SHA1 hash 456, the resultant
  GCS bucket looks like this:
    gs://$BUCKET/
      123
      456

  and the resulting App Engine API manifest is:
    {
      "app/main.py": {
        "sourceUrl": "https://storage.googleapis.com/staging-bucket/123",
        "sha1Sum": "123"
      },
      "app/tools/foo.py": {
        "sourceUrl": "https://storage.googleapis.com/staging-bucket/456",
        "sha1Sum": "456"
      }
    }

  The bucket is listed once up front; files whose hashes are already present
  in the bucket are not uploaded again.

  Args:
    upload_dir: str, path to the service's upload directory
    source_files: [str], relative paths to upload.
    bucket_ref: The reference to the bucket files will be placed in.
    max_file_size: int, File size limit per individual file or None if no
      limit.

  Returns:
    A dictionary representing the manifest.
  """
  metrics.CustomTimedEvent(metric_names.COPY_APP_FILES_START)
  with file_utils.TemporaryDirectory() as tmp_dir:
    # Index candidate uploads by SHA so identical content is deduplicated.
    manifest = _BuildDeploymentManifest(
        upload_dir, source_files, bucket_ref, tmp_dir)
    files_to_upload = _BuildFileUploadMap(
        manifest, upload_dir, bucket_ref, tmp_dir, max_file_size)
    _UploadFilesThreads(files_to_upload, bucket_ref)
  log.status.Print('File upload done.')
  log.info('Manifest: [{0}]'.format(manifest))
  metrics.CustomTimedEvent(metric_names.COPY_APP_FILES)
  return manifest
def _FormatForManifest(filename):
  """Reformat a filename for the deployment manifest if it is Windows format."""
  # Only rewrite separators when running on a backslash-separated platform.
  if os.path.sep != '\\':
    return filename
  return filename.replace('\\', '/')

View File

@@ -0,0 +1,678 @@
# -*- coding: utf-8 -*- #
# Copyright 2013 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility methods used by the deploy command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import json
import os
import re
from apitools.base.py import exceptions as apitools_exceptions
from gae_ext_runtime import ext_runtime
from googlecloudsdk.api_lib.app import appengine_api_client
from googlecloudsdk.api_lib.app import build as app_build
from googlecloudsdk.api_lib.app import cloud_build
from googlecloudsdk.api_lib.app import docker_image
from googlecloudsdk.api_lib.app import metric_names
from googlecloudsdk.api_lib.app import runtime_builders
from googlecloudsdk.api_lib.app import util
from googlecloudsdk.api_lib.app import yaml_parsing
from googlecloudsdk.api_lib.app.images import config
from googlecloudsdk.api_lib.app.runtimes import fingerprinter
from googlecloudsdk.api_lib.cloudbuild import build as cloudbuild_build
from googlecloudsdk.api_lib.services import enable_api
from googlecloudsdk.api_lib.services import exceptions as s_exceptions
from googlecloudsdk.api_lib.storage import storage_util
from googlecloudsdk.api_lib.util import exceptions as api_lib_exceptions
from googlecloudsdk.appengine.api import appinfo
from googlecloudsdk.appengine.tools import context_util
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import metrics
from googlecloudsdk.core import properties
from googlecloudsdk.core.console import progress_tracker
from googlecloudsdk.core.credentials import creds
from googlecloudsdk.core.credentials import store as c_store
from googlecloudsdk.core.util import files
from googlecloudsdk.core.util import platforms
import six
from six.moves import filter # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
# Name of the implicit App Engine service.
DEFAULT_SERVICE = 'default'
# Separator used instead of '.' so that deep appspot.com subdomains still
# match the wildcard SSL certificate (see GetAppHostname).
ALT_SEPARATOR = '-dot-'
MAX_DNS_LABEL_LENGTH = 63  # http://tools.ietf.org/html/rfc2181#section-11

# https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247(v=vs.85).aspx
# Technically, this should be 260 because of the drive, ':\', and a null
# terminator, but any time we're getting close we're in dangerous territory.
_WINDOWS_MAX_PATH = 256

# The admin API has a timeout for individual tasks; if the build is greater
# than 10 minutes, it might trigger that timeout, so it's not a candidate for
# parallelized builds.
MAX_PARALLEL_BUILD_TIME = 600

# Warning shown when we cannot confirm whether the Flexible API is enabled
# (e.g. the caller lacks permission to list enabled services).
FLEXIBLE_SERVICE_VERIFY_WARNING = (
    'Unable to verify that the Appengine Flexible API is enabled for project '
    '[{}]. You may not have permission to list enabled services on this '
    'project. If it is not enabled, this may cause problems in running your '
    'deployment. Please ask the project owner to ensure that the Appengine '
    'Flexible API has been enabled and that this account has permission to '
    'list enabled APIs.')

# Extra hint appended to the warning above when deploying as a service
# account (see PossiblyEnableFlex).
FLEXIBLE_SERVICE_VERIFY_WITH_SERVICE_ACCOUNT = (
    'Note: When deploying with a service account, the Service Management API '
    'needs to be enabled in order to verify that the Appengine Flexible API '
    'is enabled. Please ensure the Service Management API has been enabled '
    'on this project by the project owner.')

# Message used for PrepareFailureError when enabling the Flexible API fails.
PREPARE_FAILURE_MSG = (
    'Enabling the Appengine Flexible API failed on project [{}]. You '
    'may not have permission to enable APIs on this project. Please ask '
    'the project owner to enable the Appengine Flexible API on this project.')
class Error(exceptions.Error):
  """Base error for this module."""


class PrepareFailureError(Error):
  """Raised when enabling the App Engine Flexible API on the project fails."""
  pass
class WindowMaxPathError(Error):
  """Raised if a file cannot be read because of the MAX_PATH limitation."""

  # User-facing message template; filled with the offending path in __init__.
  _WINDOWS_MAX_PATH_ERROR_TEMPLATE = """\
The following file couldn't be read because its path is too long:
[{0}]
For more information on this issue and possible workarounds, please read the
following (links are specific to Node.js, but the information is generally
applicable):
* https://github.com/Microsoft/nodejstools/issues/69
* https://github.com/Microsoft/nodejs-guidelines/blob/master/windows-environment.md#max_path-explanation-and-workarounds\
"""

  def __init__(self, filename):
    # filename: str, the path that exceeded the Windows MAX_PATH limit.
    super(WindowMaxPathError, self).__init__(
        self._WINDOWS_MAX_PATH_ERROR_TEMPLATE.format(filename))
class DockerfileError(exceptions.Error):
  """Raised if a Dockerfile was found along with a non-custom runtime."""


class CloudbuildYamlError(exceptions.Error):
  """Raised if a cloudbuild.yaml was found along with a non-custom runtime."""


class CustomRuntimeFilesError(exceptions.Error):
  """Raised if a custom runtime has both a Dockerfile and a cloudbuild.yaml."""


class NoDockerfileError(exceptions.Error):
  """Raised if no Dockerfile was found for a custom runtime."""


class UnsatisfiedRequirementsError(exceptions.Error):
  """Raised if we are unable to detect the runtime from the source."""
def _NeedsDockerfile(info, source_dir):
  """Decides whether a Dockerfile must be generated for this service.

  A Dockerfile is needed whenever source_dir does not already contain one,
  regardless of whether we generate it client-side or Cloud Container
  Builder does so server-side. While deciding, this also sanity-checks the
  configuration (cloudbuild.yaml use is mutually exclusive with a
  Dockerfile, hence it is validated here too).

  Args:
    info: (googlecloudsdk.api_lib.app.yaml_parsing.ServiceYamlInfo). The
      configuration for the service.
    source_dir: str, the path to the service's source directory.

  Returns:
    bool, whether Dockerfile generation is necessary.

  Raises:
    DockerfileError: a Dockerfile is present but the runtime is not "custom".
    CloudbuildYamlError: a cloudbuild.yaml is present but the runtime is not
      "custom".
    NoDockerfileError: the runtime is "custom" but no Dockerfile was supplied.
    CustomRuntimeFilesError: a custom runtime has both a Dockerfile and a
      cloudbuild.yaml file.
  """
  dockerfile_present = os.path.exists(
      os.path.join(source_dir, config.DOCKERFILE))
  cloudbuild_present = os.path.exists(
      os.path.join(source_dir, runtime_builders.Resolver.CLOUDBUILD_FILE))

  if info.runtime != 'custom':
    # Non-custom runtimes must not ship either file; a Dockerfile would be
    # silently ignored, which is almost certainly not what the user wants.
    if dockerfile_present:
      raise DockerfileError(
          'There is a Dockerfile in the current directory, and the runtime '
          'field in {0} is currently set to [runtime: {1}]. To use your '
          'Dockerfile to build a custom runtime, set the runtime field to '
          '[runtime: custom]. To continue using the [{1}] runtime, please '
          'remove the Dockerfile from this directory.'.format(info.file,
                                                              info.runtime))
    if cloudbuild_present:
      raise CloudbuildYamlError(
          'There is a cloudbuild.yaml in the current directory, and the '
          'runtime field in {0} is currently set to [runtime: {1}]. To use '
          'your cloudbuild.yaml to build a custom runtime, set the runtime '
          'field to [runtime: custom]. To continue using the [{1}] runtime, '
          'please remove the cloudbuild.yaml from this directory.'.format(
              info.file, info.runtime))
    log.info('Need Dockerfile to be generated for runtime %s', info.runtime)
    return True

  # Custom runtime: exactly one of Dockerfile / cloudbuild.yaml is required.
  if dockerfile_present and cloudbuild_present:
    raise CustomRuntimeFilesError(
        ('A custom runtime must have exactly one of [{}] and [{}] in the '
         'source directory; [{}] contains both').format(
             config.DOCKERFILE, runtime_builders.Resolver.CLOUDBUILD_FILE,
             source_dir))
  if dockerfile_present:
    log.info('Using %s found in %s', config.DOCKERFILE, source_dir)
    return False
  if cloudbuild_present:
    log.info('Not using %s because cloudbuild.yaml was found instead.',
             config.DOCKERFILE)
    return True
  raise NoDockerfileError(
      'You must provide your own Dockerfile when using a custom runtime. '
      'Otherwise provide a "runtime" field with one of the supported '
      'runtimes.')
def ShouldUseRuntimeBuilders(service, strategy, needs_dockerfile):
  """Returns whether we should use runtime builders for this build.

  Runtime builders are only relevant when an image must be built at all;
  otherwise the decision is delegated to the given strategy, based on the
  service runtime and whether a Dockerfile still needs to be generated.

  Args:
    service: ServiceYamlInfo, The parsed service config.
    strategy: runtime_builders.RuntimeBuilderStrategy, the strategy for
      determining whether a runtime should use runtime builders.
    needs_dockerfile: bool, whether the Dockerfile in the source directory is
      absent.

  Returns:
    bool, whether to use the runtime builders.

  Raises:
    ValueError: if an unrecognized runtime_builder_strategy is given.
  """
  if not service.RequiresImage():
    # No image is being built, so runtime builders are irrelevant.
    return False
  return strategy.ShouldUseRuntimeBuilders(service.runtime, needs_dockerfile)
def _GetDockerfiles(info, dockerfile_dir):
  """Fingerprints the directory and returns generated Docker files in memory.

  Keeping the generated files in memory means nothing is dropped on disk;
  they are added directly to the archive sent to App Engine.

  Args:
    info: (googlecloudsdk.api_lib.app.yaml_parsing.ServiceYamlInfo)
      The service config.
    dockerfile_dir: str, path to the directory to fingerprint and generate
      Dockerfiles for.

  Returns:
    dict, filename relative to the archive root (str) -> file contents (str).

  Raises:
    UnsatisfiedRequirementsError: the code in the directory doesn't satisfy
      the requirements of the specified runtime type.
  """
  fingerprint_params = ext_runtime.Params(appinfo=info.parsed, deploy=True)
  configurator = fingerprinter.IdentifyDirectory(
      dockerfile_dir, fingerprint_params)
  if not configurator:
    raise UnsatisfiedRequirementsError(
        'Your application does not satisfy all of the requirements for a '
        'runtime of type [{0}]. Please correct the errors and try '
        'again.'.format(info.runtime))
  return {cfg.filename: cfg.contents
          for cfg in configurator.GenerateConfigData()}
def _GetSourceContextsForUpload(source_dir):
  """Gets source context file information for the given source directory.

  Args:
    source_dir: str, path to the service's source directory.

  Returns:
    dict, filename -> (str) source context file contents; empty on failure.
  """
  source_contexts = {}
  # Logged (not raised) on failure: missing source context is best-effort.
  failure_msg = ('Could not generate [{name}]: {error}\n'
                 'Stackdriver Debugger may not be configured or enabled on '
                 'this application. See https://cloud.google.com/debugger/ '
                 'for more information.')
  try:
    candidates = context_util.CalculateExtendedSourceContexts(source_dir)
  except context_util.GenerateSourceContextError as e:
    log.info(failure_msg.format(name=context_util.CONTEXT_FILENAME, error=e))
    return source_contexts
  try:
    best = context_util.BestSourceContext(candidates)
  except KeyError as e:
    log.info(failure_msg.format(name=context_util.CONTEXT_FILENAME, error=e))
    return source_contexts
  source_contexts[context_util.CONTEXT_FILENAME] = six.text_type(
      json.dumps(best))
  return source_contexts
def _GetDomainAndDisplayId(project_id):
"""Returns tuple (displayed app id, domain)."""
l = project_id.split(':')
if len(l) == 1:
return l[0], None
return l[1], l[0]
def _GetImageName(project, service, version, gcr_domain):
  """Builds the Docker image tag following the App Engine convention."""
  display, domain = _GetDomainAndDisplayId(project)
  # Domain-scoped projects use a dedicated format with a domain component.
  template = (config.DOCKER_IMAGE_NAME_DOMAIN_FORMAT
              if domain else config.DOCKER_IMAGE_NAME_FORMAT)
  return template.format(
      gcr_domain=gcr_domain,
      display=display,
      domain=domain,
      service=service,
      version=version)
def _GetYamlPath(source_dir, service_path, skip_files, gen_files):
  """Returns the service YAML path to use for the build.

  When the YAML lives inside the source directory and is not skipped, its
  relative path is returned directly. Otherwise a uniquely-named copy is
  scheduled for generation via gen_files.

  Args:
    source_dir: str, the absolute path to the root of the application
      directory.
    service_path: str, the absolute path to the service YAML file.
    skip_files: appengine.api.Validation._RegexStr, the validated regex
      object from the service info file.
    gen_files: dict, the dict of files to generate. May be updated if a file
      needs to be generated.

  Returns:
    str, the relative path to the service YAML file that should be used for
    build.
  """
  if files.IsDirAncestorOf(source_dir, service_path):
    relative = os.path.relpath(service_path, start=source_dir)
    if not util.ShouldSkip(skip_files, relative):
      return relative
  contents = files.ReadFileContents(service_path)
  # Checksum only guarantees a unique filename; not security-relevant.
  digest = files.Checksum().AddContents(contents.encode()).HexDigest()
  generated_name = '_app_{}.yaml'.format(digest)
  gen_files[generated_name] = contents
  return generated_name
def BuildAndPushDockerImage(
    project,
    service,
    upload_dir,
    source_files,
    version_id,
    code_bucket_ref,
    gcr_domain,
    runtime_builder_strategy=runtime_builders.RuntimeBuilderStrategy.NEVER,
    parallel_build=False,
    use_flex_with_buildpacks=False):
  """Builds and pushes a set of docker images.

  Args:
    project: str, The project being deployed to.
    service: ServiceYamlInfo, The parsed service config.
    upload_dir: str, path to the service's upload directory.
    source_files: [str], relative paths to upload.
    version_id: The version id to deploy these services under.
    code_bucket_ref: The reference to the GCS bucket where the source will be
      uploaded.
    gcr_domain: str, Cloud Registry domain, determines the physical location
      of the image. E.g. `us.gcr.io`.
    runtime_builder_strategy: runtime_builders.RuntimeBuilderStrategy, whether
      to use the new CloudBuild-based runtime builders (alternative is old
      externalized runtimes).
    parallel_build: bool, if True, enable parallel build and deploy.
    use_flex_with_buildpacks: bool, if true, use the build-image and
      run-image built through buildpacks.

  Returns:
    BuildArtifact, Representing the pushed container image or in-progress
    build, or None if this deployment requires no image.

  Raises:
    DockerfileError: if a Dockerfile is present, but the runtime is not
      "custom".
    NoDockerfileError: Raised if a user didn't supply a Dockerfile and chose a
      custom runtime.
    UnsatisfiedRequirementsError: Raised if the code in the directory doesn't
      satisfy the requirements of the specified runtime type.
    ValueError: if an unrecognized runtime_builder_strategy is given.
  """
  # Validates the Dockerfile/cloudbuild.yaml configuration even when no
  # image ends up being built below.
  needs_dockerfile = _NeedsDockerfile(service, upload_dir)
  use_runtime_builders = ShouldUseRuntimeBuilders(service,
                                                  runtime_builder_strategy,
                                                  needs_dockerfile)
  # Nothing to do if this is not an image-based deployment.
  if not service.RequiresImage():
    return None
  log.status.Print(
      'Building and pushing image for service [{service}]'
      .format(service=service.module))
  gen_files = dict(_GetSourceContextsForUpload(upload_dir))
  if needs_dockerfile and not use_runtime_builders:
    # Runtime builders generate the Dockerfile server-side in Cloud Build,
    # so we only generate one client-side when NOT using runtime builders.
    gen_files.update(_GetDockerfiles(service, upload_dir))
  image = docker_image.Image(
      dockerfile_dir=upload_dir,
      repo=_GetImageName(project, service.module, version_id, gcr_domain),
      nocache=False,
      tag=config.DOCKER_IMAGE_TAG)
  metrics.CustomTimedEvent(metric_names.CLOUDBUILD_UPLOAD_START)
  object_ref = storage_util.ObjectReference.FromBucketRef(
      code_bucket_ref, image.tagged_repo)
  relative_yaml_path = _GetYamlPath(upload_dir, service.file,
                                    service.parsed.skip_files, gen_files)
  try:
    cloud_build.UploadSource(upload_dir, source_files, object_ref,
                             gen_files=gen_files)
  except (OSError, IOError) as err:
    # On Windows an over-long path surfaces as a plain IOError; translate it
    # into a friendlier error with workaround pointers.
    if platforms.OperatingSystem.IsWindows():
      if err.filename and len(err.filename) > _WINDOWS_MAX_PATH:
        raise WindowMaxPathError(err.filename)
    raise
  metrics.CustomTimedEvent(metric_names.CLOUDBUILD_UPLOAD)
  if use_runtime_builders:
    builder_reference = runtime_builders.FromServiceInfo(
        service, upload_dir, use_flex_with_buildpacks)
    log.info('Using runtime builder [%s]', builder_reference.build_file_uri)
    builder_reference.WarnIfDeprecated()
    yaml_path = util.ConvertToPosixPath(relative_yaml_path)
    substitute = {
        '_OUTPUT_IMAGE': image.tagged_repo,
        '_GAE_APPLICATION_YAML_PATH': yaml_path,
    }
    if use_flex_with_buildpacks:
      python_version = yaml_parsing.GetRuntimeConfigAttr(
          service.parsed, 'python_version')
      # NOTE(review): GetRuntimeConfigAttr is called twice with the same
      # arguments; the second call looks redundant with `python_version`.
      if yaml_parsing.GetRuntimeConfigAttr(service.parsed, 'python_version'):
        substitute['_GOOGLE_RUNTIME_VERSION'] = python_version
    build = builder_reference.LoadCloudBuild(substitute)
  else:
    build = cloud_build.GetDefaultBuild(image.tagged_repo)
  build = cloud_build.FixUpBuild(build, object_ref)
  return _SubmitBuild(build, image, project, parallel_build)
def _SubmitBuild(build, image, project, parallel_build):
  """Submits the build, either asynchronously or serially.

  Args:
    build: A fixed up Build object.
    image: docker_image.Image, A docker image.
    project: str, The project being deployed to.
    parallel_build: bool, if True, enable parallel build and deploy.

  Returns:
    BuildArtifact, Representing the pushed container image or in-progress
    build.
  """
  build_timeout = cloud_build.GetServiceTimeoutSeconds(
      properties.VALUES.app.cloud_build_timeout.Get())
  # Long builds risk tripping the admin API's per-task timeout, so force a
  # serial deployment in that case.
  if build_timeout and build_timeout > MAX_PARALLEL_BUILD_TIME:
    parallel_build = False
    log.info(
        'Property cloud_build_timeout configured to [{0}], which exceeds '
        'the maximum build time for parallelized beta deployments of [{1}] '
        'seconds. Performing serial deployment.'.format(
            build_timeout, MAX_PARALLEL_BUILD_TIME))
  if not parallel_build:
    metrics.CustomTimedEvent(metric_names.CLOUDBUILD_EXECUTE_START)
    cloudbuild_build.CloudBuildClient().ExecuteCloudBuild(
        build, project=project)
    metrics.CustomTimedEvent(metric_names.CLOUDBUILD_EXECUTE)
    return app_build.BuildArtifact.MakeImageArtifact(image.tagged_repo)
  metrics.CustomTimedEvent(metric_names.CLOUDBUILD_EXECUTE_ASYNC_START)
  build_op = cloudbuild_build.CloudBuildClient().ExecuteCloudBuildAsync(
      build, project=project)
  return app_build.BuildArtifact.MakeBuildIdArtifactFromOp(build_op)
def DoPrepareManagedVms(gae_client):
  """Calls an API to prepare the project for App Engine Flexible.

  Args:
    gae_client: the App Engine admin client whose PrepareVmRuntime() is
      invoked.
  """
  metrics.CustomTimedEvent(metric_names.PREPARE_ENV_START)
  try:
    with progress_tracker.ProgressTracker(
        'If this is your first deployment, this may take a while'):
      # Note: this doesn't actually boot the VM, it just prepares some stuff
      # for the project via an undocumented Admin API.
      gae_client.PrepareVmRuntime()
    log.status.Print()
  except util.RPCError as err:
    # An unprepared project fails loudly later anyway, so only warn here.
    log.warning(
        ("We couldn't validate that your project is ready to deploy to App "
         'Engine Flexible Environment. If deployment fails, please check the '
         'following message and try again:\n') + six.text_type(err))
  metrics.CustomTimedEvent(metric_names.PREPARE_ENV)
def PossiblyEnableFlex(project):
  """Attempts to enable the Flexible Environment API on the project.

  Possible scenarios:
  -If Flexible Environment is already enabled, success.
  -If Flexible Environment API is not yet enabled, attempts to enable it. If
  that succeeds, success.
  -If the account doesn't have permissions to confirm that the Flexible
  Environment API is or isn't enabled on this project, succeeds with a
  warning.
    -If the account is a service account, adds an additional warning that
    the Service Management API may need to be enabled.
  -If the Flexible Environment API is not enabled on the project and the
  attempt to enable it fails, raises PrepareFailureError.

  Args:
    project: str, the project ID.

  Raises:
    PrepareFailureError: if enabling the API fails with a 403 or 404 error
      code.
    googlecloudsdk.api_lib.util.exceptions.HttpException: miscellaneous errors
      returned by server.
  """
  try:
    enable_api.EnableServiceIfDisabled(project,
                                       'appengineflex.googleapis.com')
  except s_exceptions.GetServicePermissionDeniedException:
    # If we can't find out whether the Flexible API is enabled, proceed with
    # a warning.
    warning = FLEXIBLE_SERVICE_VERIFY_WARNING.format(project)
    # If user is using a service account, add more info about what might
    # have gone wrong.
    credential = c_store.LoadIfEnabled(use_google_auth=True)
    if credential and creds.IsServiceAccountCredentials(credential):
      warning += '\n\n{}'.format(FLEXIBLE_SERVICE_VERIFY_WITH_SERVICE_ACCOUNT)
    log.warning(warning)
  except s_exceptions.EnableServiceException:
    # If enabling the Flexible API fails due to a permissions error, the
    # deployment fails.
    raise PrepareFailureError(PREPARE_FAILURE_MSG.format(project))
  except apitools_exceptions.HttpError as err:
    # The deployment should also fail if there are unforeseen errors in
    # enabling the Flexible API. If so, display detailed information.
    raise api_lib_exceptions.HttpException(
        err, error_format=('Error [{status_code}] {status_message}'
                           '{error.details?'
                           '\nDetailed error information:\n{?}}'))
def UseSsl(service_info):
  """Returns the 'secure' setting of the handler serving the root URL.

  The result can be 'always', 'optional', or 'never', depending on when the
  URL is served over HTTPS. A small number of cases may be missed, but HTTP
  is always okay (an HTTP URL to an HTTPS-only service results in a
  redirect).

  Args:
    service_info: ServiceYamlInfo, the service configuration.

  Returns:
    str, the 'secure' setting of the handler for the root URL.
  """
  handlers = service_info.parsed.handlers
  # Ti runtimes with no explicit handlers serve over both schemes.
  if service_info.is_ti_runtime and not handlers:
    return appinfo.SECURE_HTTP_OR_HTTPS
  for handler in handlers:
    try:
      if re.match(handler.url + '$', '/'):
        return handler.secure
    except re.error:
      # AppEngine uses POSIX Extended regular expressions, which are not 100%
      # compatible with Python's re module.
      pass
  return appinfo.SECURE_HTTP
def GetAppHostname(app=None, app_id=None, service=None, version=None,
                   use_ssl=appinfo.SECURE_HTTP, deploy=True):
  """Returns the hostname of the given version of the deployed app.

  Args:
    app: Application resource. One of {app, app_id} must be given.
    app_id: str, project ID. One of {app, app_id} must be given. If both are
      provided, the hostname from app is preferred.
    service: str, the (optional) service being deployed.
    version: str, the deployed version ID (omit to get the default version
      URL).
    use_ssl: bool, whether to construct an HTTPS URL.
    deploy: bool, if this is called during a deployment.

  Returns:
    str. Constructed URL.

  Raises:
    TypeError: if neither an app nor an app_id is provided.
  """
  if not app and not app_id:
    raise TypeError('Must provide an application resource or application ID.')
  version = version or ''
  service_name = service or ''
  # The default service is implicit and omitted from the hostname.
  if service == DEFAULT_SERVICE:
    service_name = ''
  if not app:
    api_client = appengine_api_client.AppengineApiClient.GetApiClient()
    app = api_client.GetApplication()
  if app:
    app_id, domain = app.defaultHostname.split('.', 1)
  # NOTE(review): if no app could be resolved above, `domain` is never bound
  # and the final format() would raise NameError — confirm callers always
  # supply or can fetch an application.
  # Normally, AppEngine URLs are of the form
  # 'http[s]://version.service.app.appspot.com'. However, the SSL certificate
  # for appspot.com is not valid for subdomains of subdomains of appspot.com
  # (e.g. 'https://app.appspot.com/' is okay;
  # 'https://service.app.appspot.com/' is not). To deal with this, AppEngine
  # recognizes URLs like 'http[s]://version-dot-service-dot-app.appspot.com/'.
  #
  # This works well as long as the domain name part constructed in this
  # fashion is less than 63 characters long, as per the DNS spec. If the
  # domain name part is longer than that, we are forced to use the URL with
  # an invalid certificate.
  #
  # We've tried to do the best possible thing in every case here.
  subdomain_parts = list(filter(bool, [version, service_name, app_id]))
  scheme = 'http'
  if use_ssl == appinfo.SECURE_HTTP:
    subdomain = '.'.join(subdomain_parts)
    scheme = 'http'
  else:
    subdomain = ALT_SEPARATOR.join(subdomain_parts)
    if len(subdomain) <= MAX_DNS_LABEL_LENGTH:
      scheme = 'https'
    else:
      # The '-dot-' form exceeds a single DNS label; fall back to dotted
      # subdomains (invalid certificate) and warn accordingly.
      if deploy:
        format_parts = ['$VERSION_ID', '$SERVICE_ID', '$APP_ID']
        subdomain_format = ALT_SEPARATOR.join(
            [j for (i, j) in zip([version, service_name, app_id],
                                 format_parts)
             if i])
        msg = ('This deployment will result in an invalid SSL certificate '
               'for service [{0}]. The total length of your subdomain in '
               'the format {1} should not exceed {2} characters. Please '
               'verify that the certificate corresponds to the parent '
               'domain of your application when you connect.').format(
                   service, subdomain_format, MAX_DNS_LABEL_LENGTH)
        log.warning(msg)
      subdomain = '.'.join(subdomain_parts)
      if use_ssl == appinfo.SECURE_HTTP_OR_HTTPS:
        scheme = 'http'
      elif use_ssl == appinfo.SECURE_HTTPS:
        if not deploy:
          msg = ('Most browsers will reject the SSL certificate for '
                 'service [{0}].').format(service)
          log.warning(msg)
        scheme = 'https'
  return '{0}://{1}.{2}'.format(scheme, subdomain, domain)
# Default filename of the service configuration to deploy.
DEFAULT_DEPLOYABLE = 'app.yaml'

View File

@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*- #
# Copyright 2014 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Encapsulation of a docker image."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
class Image(object):
  """Docker image that requires building and should be removed afterwards."""

  def __init__(self, dockerfile_dir=None, repo=None, tag=None, nocache=False,
               rm=True):
    """Initializer for Image.

    Args:
      dockerfile_dir: str, Path to the directory with the Dockerfile.
      repo: str, Repository name to be applied to the image in case of
        successful build.
      tag: str, Repository tag to be applied to the image in case of
        successful build.
      nocache: boolean, True if cache should not be used when building the
        image.
      rm: boolean, True if intermediate images should be removed after a
        successful build. Default value is set to True because this is the
        default value used by "docker build" command.
    """
    self._dockerfile_dir = dockerfile_dir
    self._repo = repo
    self._tag = tag
    self._nocache = nocache
    self._rm = rm
    self._id = None  # Populated later, during the Build() step.

  @property
  def dockerfile_dir(self):
    """str, the directory the image is to be built from."""
    return self._dockerfile_dir

  @property
  def id(self):
    """str, 64-hex-digit image identifier (or a 12-char prefix), or None."""
    return self._id

  @property
  def repo(self):
    """str, the image repository name."""
    return self._repo

  @property
  def tag(self):
    """str, the image tag."""
    return self._tag

  @property
  def tagged_repo(self):
    """str, 'repo:tag' when a tag is set, otherwise just the repo."""
    if not self._tag:
      return self._repo
    return '{0}:{1}'.format(self._repo, self._tag)

View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*- #
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auxiliary environment information about App Engine."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import re
import enum
from googlecloudsdk.api_lib.app import runtime_registry
# Regexes identifying second-generation ("Ti") runtimes by runtime ID.
NODE_TI_RUNTIME_EXPR = re.compile(r'nodejs\d*')
PHP_TI_RUNTIME_EXPR = re.compile(r'php[789]\d*')
PYTHON_TI_RUNTIME_EXPR = re.compile(r'python3\d*')
# Allow things like go110 and g110beta1
GO_TI_RUNTIME_EXPR = re.compile(r'go1\d\d(\w+\d)?')
# Java 7, 8 still allows handlers, so they are deliberately excluded here
# (this pattern matches two-digit IDs like java11, java17, java21).
JAVA_TI_RUNTIME_EXPR = re.compile(r'java[123456]\d*')
class Environment(enum.Enum):
  """Enum for different application environments.

  STANDARD corresponds to App Engine Standard applications.
  FLEX corresponds to any App Engine `env: flex` applications.
  MANAGED_VMS corresponds to `vm: true` applications.
  """
  # Values are arbitrary but stable identifiers; do not reorder.
  STANDARD = 1
  MANAGED_VMS = 2
  FLEX = 3
def GetTiRuntimeRegistry():
  """A simple registry whose `Get()` method answers True if runtime is Ti.

  Returns:
    runtime_registry.Registry backed by _TI_RUNTIME_REGISTRY, answering
    False for any runtime/environment pair not listed there.
  """
  return runtime_registry.Registry(_TI_RUNTIME_REGISTRY, default=False)
# Module-level aliases for brevity in the registry below.
STANDARD = Environment.STANDARD
FLEX = Environment.FLEX
MANAGED_VMS = Environment.MANAGED_VMS

# Maps (runtime regex, environments) entries to True for Ti runtimes;
# consulted by GetTiRuntimeRegistry() with a default of False.
_TI_RUNTIME_REGISTRY = {
    runtime_registry.RegistryEntry(NODE_TI_RUNTIME_EXPR, {STANDARD}): True,
    runtime_registry.RegistryEntry(PHP_TI_RUNTIME_EXPR, {STANDARD}): True,
    runtime_registry.RegistryEntry(PYTHON_TI_RUNTIME_EXPR, {STANDARD}): True,
    runtime_registry.RegistryEntry(GO_TI_RUNTIME_EXPR, {STANDARD}): True,
    runtime_registry.RegistryEntry(JAVA_TI_RUNTIME_EXPR, {STANDARD}): True,
}

View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module holds exceptions raised by api lib."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.core import exceptions
class Error(exceptions.Error):
  """Base error for this module."""
class ConfigError(Error):
  """Raised when unable to parse a config file."""

  def __init__(self, message=None, **kwargs):
    # Fall back to a generic message when none (or an empty one) is given.
    message = message or 'Config Error.'
    super(ConfigError, self).__init__(message, **kwargs)

View File

@@ -0,0 +1,124 @@
# -*- coding: utf-8 -*- #
# Copyright 2015 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adapter to use externalized runtimes loaders from gcloud."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import os
from gae_ext_runtime import ext_runtime
from googlecloudsdk.core import config
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import execution_utils
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.console import console_io
class NoRuntimeRootError(exceptions.Error):
  """Raised when we can't determine where the runtimes are."""
def _GetRuntimeDefDir():
  """Returns the root directory containing externalized runtime definitions.

  Returns:
    str, the value of the app/runtime_root property (settable via the
    CLOUDSDK_APP_RUNTIME_ROOT environment variable).

  Raises:
    NoRuntimeRootError: if the property is not set.
  """
  runtime_root = properties.VALUES.app.runtime_root.Get()
  if runtime_root:
    return runtime_root
  # Fixed typo in the user-facing message: 'environmnent' -> 'environment'.
  raise NoRuntimeRootError('Unable to determine the root directory where '
                           'GAE runtimes are stored. Please define '
                           'the CLOUDSDK_APP_RUNTIME_ROOT environment '
                           'variable.')
class GCloudExecutionEnvironment(ext_runtime.ExecutionEnvironment):
  """ExecutionEnvironment implemented using gcloud's core functions."""

  def GetPythonExecutable(self):
    """Returns the path to the Python interpreter gcloud runs under."""
    return execution_utils.GetPythonExecutable()

  def CanPrompt(self):
    """Returns True if an interactive console prompt is possible."""
    return console_io.CanPrompt()

  def PromptResponse(self, message):
    """Prompts the user with message and returns the raw response."""
    return console_io.PromptResponse(message)

  def Print(self, message):
    """Writes message to gcloud's status output stream."""
    return log.status.Print(message)
class CoreRuntimeLoader(object):
  """A loader stub for the core runtimes.

  The externalized core runtimes are currently distributed with the cloud sdk.
  This class encapsulates the name of a core runtime to avoid having to load
  it at module load time. Instead, the wrapped runtime is demand-loaded when
  the Fingerprint() method is called.
  """

  def __init__(self, name, visible_name, allowed_runtime_names):
    # name: str, directory name of the runtime definition under the runtime
    # root (see _GetRuntimeDefDir).
    self._name = name
    # Lazily-loaded ExternalizedRuntime; populated on first Fingerprint().
    self._rep = None
    self._visible_name = visible_name
    self._allowed_runtime_names = allowed_runtime_names

  # These need to be named this way because they're constants in the
  # non-externalized implementation.
  # pylint:disable=invalid-name
  @property
  def ALLOWED_RUNTIME_NAMES(self):
    """The runtime names this loader accepts."""
    return self._allowed_runtime_names

  # pylint:disable=invalid-name
  @property
  def NAME(self):
    """The user-visible name of the runtime."""
    return self._visible_name

  def Fingerprint(self, path, params):
    """Loads the runtime definition on first use and fingerprints path."""
    if not self._rep:
      path_to_runtime = os.path.join(_GetRuntimeDefDir(), self._name)
      self._rep = ext_runtime.ExternalizedRuntime.Load(
          path_to_runtime, GCloudExecutionEnvironment())
    return self._rep.Fingerprint(path, params)
# Appended to error messages when prompts are disabled via the
# core/disable_prompts property; tells the user how to re-enable prompting.
_PROMPTS_DISABLED_ERROR_MESSAGE = (
    '("disable_prompts" set to true, run "gcloud config set disable_prompts '
    'False" to fix this)')
def GetNonInteractiveErrorMessage():
  """Returns useful instructions when running non-interactive.

  Certain fingerprinting modules require interactive functionality. It isn't
  always obvious why gcloud is running in non-interactive mode (e.g. when
  "disable_prompts" is set) so this returns an appropriate addition to the
  error message in these circumstances.

  Returns:
    (str) The appropriate error message snippet.
  """
  prompts_disabled = properties.VALUES.core.disable_prompts.GetBool()
  if not prompts_disabled:
    # The other cause of non-interactivity (running detached from a terminal)
    # is self-evident, so no extra explanation is needed.
    return ''
  # Leading space so the snippet meshes with its display context.
  return ' ' + _PROMPTS_DISABLED_ERROR_MESSAGE

View File

@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*- #
# Copyright 2014 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Magic constants for images module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
# The version of the docker API the docker-py client uses.
# Warning: other versions might have different return values for some functions.
DOCKER_PY_VERSION = 'auto'
# Timeout of HTTP request from docker-py client to docker daemon, in seconds.
DOCKER_D_REQUEST_TIMEOUT = 300
# Image name template for projects without a domain-scoped project id.
DOCKER_IMAGE_NAME_FORMAT = (
    '{gcr_domain}/{display}/appengine/{service}.{version}')
DOCKER_IMAGE_TAG = 'latest'
# Image name template for projects with a domain-scoped project id
# (adds the {domain} component).
DOCKER_IMAGE_NAME_DOMAIN_FORMAT = (
    '{gcr_domain}/{domain}/{display}/appengine/{service}.{version}')
# Name of the Dockerfile.
DOCKERFILE = 'Dockerfile'
# A map of runtimes values if they need to be overwritten to match our
# base Docker images naming rules.
CANONICAL_RUNTIMES = {'java7': 'java'}

View File

@@ -0,0 +1,273 @@
# -*- coding: utf-8 -*- #
# Copyright 2015 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for manipulating GCE instances running an App Engine project."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import re
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.console import console_io
from six.moves import filter # pylint: disable=redefined-builtin
from six.moves import map # pylint: disable=redefined-builtin
class InvalidInstanceSpecificationError(exceptions.Error):
  """Raised when an instance specification is over- or under-specified."""
  pass


class SelectInstanceError(exceptions.Error):
  """Raised when an instance cannot be selected (e.g. prompts disabled)."""
  pass
class Instance(object):
  """Value class for instances running the current App Engine project."""

  # TODO(b/27900246): Once API supports "Get" verb, convert to use resource
  # parser.
  _INSTANCE_NAME_PATTERN = ('apps/(?P<project>.*)/'
                            'services/(?P<service>.*)/'
                            'versions/(?P<version>.*)/'
                            'instances/(?P<instance>.*)')

  def __init__(self, service, version, id_, instance=None):
    """Initializes an instance value object.

    Args:
      service: str or None, the service the instance belongs to.
      version: str or None, the version the instance belongs to.
      id_: str, the instance id.
      instance: the Client API instance object, if available.
    """
    self.service = service
    self.version = version
    self.id = id_
    self.instance = instance  # The Client API instance object

  @classmethod
  def FromInstanceResource(cls, instance):
    """Creates an Instance from a Client API instance resource message."""
    match = re.match(cls._INSTANCE_NAME_PATTERN, instance.name)
    service = match.group('service')
    version = match.group('version')
    return cls(service, version, instance.id, instance)

  @classmethod
  def FromResourcePath(cls, path, service=None, version=None):
    """Convert a resource path into an Instance.

    A resource path is of the form '<service>/<version>/<instance>'.
    '<service>' and '<version>' can be omitted, in which case they are None in
    the resulting instance.

    >>> (Instance.FromResourcePath('a/b/c') ==
    ...  Instance('a', 'b', 'c'))
    True
    >>> (Instance.FromResourcePath('b/c', service='a') ==
    ...  Instance('a', 'b', 'c'))
    True
    >>> (Instance.FromResourcePath('c', service='a', version='b') ==
    ...  Instance('a', 'b', 'c'))
    True

    Args:
      path: str, the resource path
      service: the service of the instance (replaces the service from the
        resource path)
      version: the version of the instance (replaces the version from the
        resource path)

    Returns:
      Instance, an Instance representing the path

    Raises:
      InvalidInstanceSpecificationError: if the instance is over- or
        under-specified
    """
    parts = path.split('/')
    if len(parts) == 1:
      path_service, path_version, instance = None, None, parts[0]
    elif len(parts) == 2:
      path_service, path_version, instance = None, parts[0], parts[1]
    elif len(parts) == 3:
      path_service, path_version, instance = parts
    else:
      # Fixed broken message punctuation (was "...instance id, " + ".").
      raise InvalidInstanceSpecificationError(
          'Instance resource path is incorrectly specified. '
          'Please provide at most one service, version, and instance id.'
          '\n\n'
          'You provided:\n' + path)
    if path_service and service and path_service != service:
      raise InvalidInstanceSpecificationError(
          'Service [{0}] is inconsistent with specified instance [{1}].'.format(
              service, path))
    service = service or path_service
    if path_version and version and path_version != version:
      raise InvalidInstanceSpecificationError(
          'Version [{0}] is inconsistent with specified instance [{1}].'.format(
              version, path))
    version = version or path_version
    return cls(service, version, instance)

  def __eq__(self, other):
    return (type(self) is type(other) and
            self.service == other.service and
            self.version == other.version and
            self.id == other.id)

  def __ne__(self, other):
    return not self == other

  # needed for set comparisons in tests
  def __hash__(self):
    return hash((self.service, self.version, self.id))

  def __str__(self):
    return '/'.join(filter(bool, [self.service, self.version, self.id]))

  def __cmp__(self, other):
    # __cmp__ is only consulted by Python 2; the `cmp` builtin it previously
    # used does not exist on Python 3 and would raise NameError. Emulate it
    # with tuple comparisons so the method works on both interpreters.
    self_key = (self.service, self.version, self.id)
    other_key = (other.service, other.version, other.id)
    return (self_key > other_key) - (self_key < other_key)
def FilterInstances(instances, service=None, version=None, instance=None):
  """Filter a list of App Engine instances.

  Args:
    instances: list of AppEngineInstance, all App Engine instances
    service: str, the name of the service to filter by or None to match all
      services
    version: str, the name of the version to filter by or None to match all
      versions
    instance: str, the instance id to filter by or None to match all versions.

  Returns:
    list of instances matching the given filters
  """
  def _Matches(candidate):
    # A falsy filter value means "match everything" for that dimension.
    return ((not service or candidate.service == service) and
            (not version or candidate.version == version) and
            (not instance or candidate.id == instance))

  return [candidate for candidate in instances if _Matches(candidate)]
def GetMatchingInstance(instances, service=None, version=None, instance=None):
  """Return exactly one matching instance.

  If instance is given, filter down based on the given criteria (service,
  version, instance) and return the matching instance (it is an error unless
  exactly one instance matches).

  Otherwise, prompt the user to select the instance interactively.

  Args:
    instances: list of AppEngineInstance, all instances to select from
    service: str, a service to filter by or None to include all services
    version: str, a version to filter by or None to include all versions
    instance: str, an instance ID to filter by. If not given, the instance will
      be selected interactively.

  Returns:
    AppEngineInstance, an instance from the given list.

  Raises:
    InvalidInstanceSpecificationError: if no matching instances or more than one
      matching instance were found.
  """
  if not instance:
    # No explicit id: fall back to interactive drill-down selection.
    return SelectInstanceInteractive(instances, service=service,
                                     version=version)

  matching = FilterInstances(instances, service, version, instance)
  if not matching:
    raise InvalidInstanceSpecificationError(
        'No instances match the given specification.\n\n'
        'All instances: {0}'.format(list(sorted(map(str, instances)))))
  if len(matching) > 1:
    raise InvalidInstanceSpecificationError(
        'More than one instance matches the given specification.\n\n'
        'Matching instances: {0}'.format(list(sorted(map(str, matching)))))
  return matching[0]
def SelectInstanceInteractive(all_instances, service=None, version=None):
  """Interactively choose an instance from a provided list.

  Example interaction:

      Which service?
       [1] default
       [2] service1
      Please enter your numeric choice:  1

      Which version?
       [1] v1
       [2] v2
      Please enter your numeric choice:  1

      Which instance?
       [1] i1
       [2] i2
      Please enter your numeric choice:  1

  Skips any prompts with only one option.

  Args:
    all_instances: list of AppEngineInstance, the list of instances to drill
      down on.
    service: str. If provided, skip the service prompt.
    version: str. If provided, skip the version prompt.

  Returns:
    AppEngineInstance, the selected instance from the list.

  Raises:
    SelectInstanceError: if no versions matching the criteria can be found or
      prompts are disabled.
  """
  if properties.VALUES.core.disable_prompts.GetBool():
    raise SelectInstanceError(
        'Cannot interactively select instances with prompts disabled.')

  # Defined here to close over all_instances for the error message
  def _PromptOptions(options, type_):
    """Given an iterable options of type type_, prompt and return one."""
    options = sorted(set(options), key=str)
    if len(options) > 1:
      idx = console_io.PromptChoice(options, message='Which {0}?'.format(type_))
    elif len(options) == 1:
      # Single choice: pick it automatically, but tell the user what was
      # chosen so the implicit decision is visible.
      idx = 0
      log.status.Print('Choosing [{0}] for {1}.\n'.format(options[0], type_))
    else:
      if all_instances:
        msg = ('No instances could be found matching the given criteria.\n\n'
               'All instances:\n' +
               '\n'.join(
                   map('* [{0}]'.format, sorted(all_instances, key=str))))
      else:
        msg = 'No instances were found for the current project [{0}].'.format(
            properties.VALUES.core.project.Get(required=True))
      raise SelectInstanceError(msg)
    return options[idx]

  # Drill down one level at a time (service -> version -> instance),
  # narrowing the candidate list after each answer.
  matching_instances = FilterInstances(all_instances, service, version)
  service = _PromptOptions((i.service for i in matching_instances), 'service')
  matching_instances = FilterInstances(matching_instances, service=service)
  version = _PromptOptions((i.version for i in matching_instances), 'version')
  matching_instances = FilterInstances(matching_instances, version=version)
  return _PromptOptions(matching_instances, 'instance')

View File

@@ -0,0 +1,283 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General formatting utils, App Engine specific formatters."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.logging import util
from googlecloudsdk.core import log
from googlecloudsdk.core import resources
from googlecloudsdk.core.util import times
import six
# Severity levels accepted by GetFilters' `level` argument; 'any' disables
# severity filtering.
LOG_LEVELS = ['critical', 'error', 'warning', 'info', 'debug', 'any']
# Request logs come from different sources if the app is Flex or Standard.
FLEX_REQUEST = 'nginx.request'
STANDARD_REQUEST = 'request_log'
# Default set of log sources: app output streams, crash log, and the
# request logs for both environments.
DEFAULT_LOGS = ['stderr', 'stdout', 'crash.log',
                FLEX_REQUEST, STANDARD_REQUEST]
# Fully-qualified nginx log ids recognized by FormatNginxLogEntry.
NGINX_LOGS = [
    'appengine.googleapis.com/nginx.request',
    'appengine.googleapis.com/nginx.health_check']
def GetFilters(project, log_sources, service=None, version=None, level='any'):
  """Returns filters for App Engine app logs.

  Args:
    project: string name of project ID.
    log_sources: List of streams to fetch logs from.
    service: String name of service to fetch logs from.
    version: String name of version to fetch logs from.
    level: A string representing the severity of logs to fetch.

  Returns:
    A list of filter strings.
  """
  def _IdsForSource(log_type):
    """Log ids a single source expands to (stderr/stdout also exist bare)."""
    ids = ['appengine.googleapis.com/{0}'.format(log_type)]
    if log_type in ('stderr', 'stdout'):
      ids.append(log_type)
    return ids

  filters = ['resource.type="gae_app"']
  if service:
    filters.append('resource.labels.module_id="{0}"'.format(service))
  if version:
    filters.append('resource.labels.version_id="{0}"'.format(version))
  if level != 'any':
    filters.append('severity>={0}'.format(level.upper()))

  log_ids = [log_id
             for log_type in sorted(log_sources)
             for log_id in _IdsForSource(log_type)]
  res = resources.REGISTRY.Parse(
      project, collection='appengine.projects').RelativeName()
  filters.append(_LogFilterForIds(log_ids, res))
  return filters
def _LogFilterForIds(log_ids, parent):
  """Constructs a log filter expression from the log_ids and parent name.

  Args:
    log_ids: list of str, the log ids to include in the filter.
    parent: str, the parent resource name the log ids belong to.

  Returns:
    str, a `logName=...` filter expression, or None when log_ids is empty.
  """
  if not log_ids:
    return None
  quoted_names = []
  for log_id in log_ids:
    quoted_names.append(
        '"{0}"'.format(util.CreateLogResourceName(parent, log_id)))
  joined = ' OR '.join(quoted_names)
  if len(log_ids) > 1:
    # Parenthesize so the OR-disjunction binds tighter than the equality.
    joined = '(%s)' % joined
  return 'logName=%s' % joined
def FormatAppEntry(entry):
  """App Engine formatter for `LogPrinter`.

  Args:
    entry: A log entry message emitted from the V2 API client.

  Returns:
    A string representing the entry or None if there was no text payload.
  """
  # TODO(b/36056460): Output others than text here too?
  if entry.resource.type != 'gae_app':
    return None
  # Prefer proto payload, then json payload, falling back to the raw text.
  structured = entry.protoPayload or entry.jsonPayload
  text = six.text_type(structured) if structured else entry.textPayload
  service, version = _ExtractServiceAndVersion(entry)
  return '{service}[{version}] {text}'.format(
      service=service, version=version, text=text)
def FormatRequestLogEntry(entry):
  """App Engine request_log formatter for `LogPrinter`.

  Args:
    entry: A log entry message emitted from the V2 API client.

  Returns:
    A string representing the entry if it is a request entry.
  """
  if entry.resource.type != 'gae_app':
    return None
  if util.ExtractLogId(entry.logName) != 'appengine.googleapis.com/request_log':
    return None
  service, version = _ExtractServiceAndVersion(entry)
  props = entry.protoPayload.additionalProperties

  def _Field(key, accessor):
    """Value of the first property named `key`, or '-' if absent."""
    for prop in props:
      if prop.key == key:
        return accessor(prop.value)
    return '-'

  msg = '"{method} {resource} {http_version}" {status}'.format(
      method=_Field('method', lambda v: v.string_value),
      resource=_Field('resource', lambda v: v.string_value),
      http_version=_Field('httpVersion', lambda v: v.string_value),
      status=_Field('status', lambda v: v.integer_value))
  return '{service}[{version}] {msg}'.format(service=service,
                                             version=version,
                                             msg=msg)
def FormatNginxLogEntry(entry):
  """App Engine nginx.* formatter for `LogPrinter`.

  Args:
    entry: A log entry message emitted from the V2 API client.

  Returns:
    A string representing the entry if it is a request entry.
  """
  if entry.resource.type != 'gae_app':
    return None
  if util.ExtractLogId(entry.logName) not in NGINX_LOGS:
    return None
  service, version = _ExtractServiceAndVersion(entry)
  request = entry.httpRequest
  # Missing fields render as '-', mirroring common access-log conventions.
  msg = '"{method} {resource}" {status}'.format(
      method=request.requestMethod or '-',
      resource=request.requestUrl or '-',
      status=request.status or '-')
  return '{service}[{version}] {msg}'.format(service=service,
                                             version=version,
                                             msg=msg)
def _ExtractServiceAndVersion(entry):
"""Extract service and version from a App Engine log entry.
Args:
entry: An App Engine log entry.
Returns:
A 2-tuple of the form (service_id, version_id)
"""
# TODO(b/36051034): If possible, extract instance ID too
ad_prop = entry.resource.labels.additionalProperties
service = next(x.value
for x in ad_prop
if x.key == 'module_id')
version = next(x.value
for x in ad_prop
if x.key == 'version_id')
return (service, version)
class LogPrinter(object):
  """Formats V2 API log entries to human readable text on a best effort basis.

  A LogPrinter consists of a collection of formatter functions which attempts
  to format specific log entries in a human readable form. The `Format` method
  safely returns a human readable string representation of a log entry, even if
  the provided formatters fails.

  The output format is `{timestamp} {log_text}`, where `timestamp` has a
  configurable but consistent format within a LogPrinter whereas `log_text` is
  emitted from one of its formatters (and truncated if necessary).

  See https://cloud.google.com/logging/docs/api/introduction_v2

  Attributes:
    api_time_format: str, the output format to print. See datetime.strftime()
    max_length: The maximum length of a formatted log entry after truncation.
  """

  def __init__(self, api_time_format='%Y-%m-%d %H:%M:%S', max_length=None):
    # Formatters are tried in registration order; the first one returning a
    # non-empty string wins (see _LogEntryToText).
    self.formatters = []
    self.api_time_format = api_time_format
    self.max_length = max_length

  def Format(self, entry):
    """Safely formats a log entry into human readable text.

    Args:
      entry: A log entry message emitted from the V2 API client.

    Returns:
      A string without line breaks respecting the `max_length` property.
    """
    text = self._LogEntryToText(entry)
    # Collapse multi-line payloads so each entry occupies exactly one row.
    text = text.strip().replace('\n', '  ')
    try:
      time = times.FormatDateTime(times.ParseDateTime(entry.timestamp),
                                  self.api_time_format)
    except times.Error:
      log.warning('Received timestamp [{0}] does not match expected'
                  ' format.'.format(entry.timestamp))
      # Placeholder keeps column alignment when the timestamp is unparseable.
      time = '????-??-?? ??:??:??'
    out = '{timestamp} {log_text}'.format(
        timestamp=time,
        log_text=text)
    if self.max_length and len(out) > self.max_length:
      # Truncate, reserving three characters for the ellipsis.
      out = out[:self.max_length - 3] + '...'
    return out

  def RegisterFormatter(self, formatter):
    """Attach a log entry formatter function to the printer.

    Note that if multiple formatters are attached to the same printer, the first
    added formatter that successfully formats the entry will be used.

    Args:
      formatter: A formatter function which accepts a single argument, a log
          entry. The formatter must either return the formatted log entry as a
          string, or None if it is unable to format the log entry.
          The formatter is allowed to raise exceptions, which will be caught and
          ignored by the printer.
    """
    self.formatters.append(formatter)

  def _LogEntryToText(self, entry):
    """Use the formatters to convert a log entry to unprocessed text."""
    out = None
    for fn in self.formatters + [self._FallbackFormatter]:
      # pylint:disable=bare-except
      try:
        out = fn(entry)
        if out:
          break
      except KeyboardInterrupt as e:
        # Never swallow a user interrupt.
        raise e
      except:
        # Formatters are allowed to fail; fall through to the next one.
        pass
    if not out:
      log.debug('Could not format log entry: %s %s %s', entry.timestamp,
                entry.logName, entry.insertId)
      out = ('< UNREADABLE LOG ENTRY {0}. OPEN THE DEVELOPER CONSOLE TO '
             'INSPECT. >'.format(entry.insertId))
    return out

  def _FallbackFormatter(self, entry):
    """Last-resort formatter: serialize whichever payload is present."""
    # TODO(b/36057358): Is there better serialization for messages than
    # six.text_type()?
    if entry.protoPayload:
      return six.text_type(entry.protoPayload)
    elif entry.jsonPayload:
      return six.text_type(entry.jsonPayload)
    else:
      return entry.textPayload

View File

@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*- #
# Copyright 2015 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Appengine CSI metric names."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
# Metric names for CSI

# Reserved CSI metric prefix for appengine
_APPENGINE_PREFIX = 'app_deploy_'

# "Start" suffix, appended to a metric name to form its begin-event name.
START = '_start'

# Time to upload project source tarball to GCS
CLOUDBUILD_UPLOAD = _APPENGINE_PREFIX + 'cloudbuild_upload'
CLOUDBUILD_UPLOAD_START = CLOUDBUILD_UPLOAD + START

# Time to execute Argo Cloud Build request
CLOUDBUILD_EXECUTE = _APPENGINE_PREFIX + 'cloudbuild_execute'
CLOUDBUILD_EXECUTE_START = CLOUDBUILD_EXECUTE + START
# Asynchronous variant of the Cloud Build execution metric.
CLOUDBUILD_EXECUTE_ASYNC = CLOUDBUILD_EXECUTE + '_async'
CLOUDBUILD_EXECUTE_ASYNC_START = CLOUDBUILD_EXECUTE_ASYNC + START

# Time to copy application files to the application code bucket
COPY_APP_FILES = _APPENGINE_PREFIX + 'copy_app_files'
COPY_APP_FILES_START = COPY_APP_FILES + START

# Time to copy application files to the application code bucket without gsutil.
# No longer used, but may still come in from old versions.
COPY_APP_FILES_NO_GSUTIL = _APPENGINE_PREFIX + 'copy_app_files_no_gsutil'

# Time for a deploy using appengine API
DEPLOY_API = _APPENGINE_PREFIX + 'deploy_api'
DEPLOY_API_START = DEPLOY_API + START

# Time for API request to get the application code bucket.
GET_CODE_BUCKET = _APPENGINE_PREFIX + 'get_code_bucket'
GET_CODE_BUCKET_START = GET_CODE_BUCKET + START

# Time for setting deployed version to default using appengine API
SET_DEFAULT_VERSION_API = (_APPENGINE_PREFIX + 'set_default_version_api')
SET_DEFAULT_VERSION_API_START = SET_DEFAULT_VERSION_API + START

# Time for API request to prepare environment for VMs.
PREPARE_ENV = _APPENGINE_PREFIX + 'prepare_environment'
PREPARE_ENV_START = PREPARE_ENV + START

# Time to update config files.
UPDATE_CONFIG = _APPENGINE_PREFIX + 'update_config'
UPDATE_CONFIG_START = UPDATE_CONFIG + START

# First service deployment
FIRST_SERVICE_DEPLOY = _APPENGINE_PREFIX + 'first_service_deploy'
FIRST_SERVICE_DEPLOY_START = FIRST_SERVICE_DEPLOY + START

View File

@@ -0,0 +1,319 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with long running operations go/long-running-operation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import json
from apitools.base.py import encoding
from apitools.base.py import exceptions as apitools_exceptions
import enum
from googlecloudsdk.api_lib.app import exceptions as app_exceptions
from googlecloudsdk.api_lib.util import exceptions as api_exceptions
from googlecloudsdk.api_lib.util import requests
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import resources
import six
# Default is to retry every 5 seconds for 1 hour.
DEFAULT_OPERATION_RETRY_INTERVAL = 5
# (polls per minute) * (minutes per hour) = total polls in one hour.
DEFAULT_OPERATION_MAX_TRIES = (60 // DEFAULT_OPERATION_RETRY_INTERVAL) * 60
def CallAndCollectOpErrors(method, *args, **kwargs):
  """Wrapper for method(...) which re-raises operation-style errors.

  Args:
    method: Original method to call.
    *args: Positional arguments to method.
    **kwargs: Keyword arguments to method.

  Raises:
    MiscOperationError: If the method call itself raises one of the exceptions
      listed below. Otherwise, the original exception is raised. Preserves
      stack trace. Re-uses the error string from original error or in the case
      of HttpError, we synthesize human-friendly string from HttpException.
      However, HttpException is neither raised nor part of the stack trace.

  Returns:
    Result of calling method(*args, **kwargs).
  """
  try:
    return method(*args, **kwargs)
  except apitools_exceptions.HttpError as http_err:
    # Create HttpException locally only to get its human friendly string;
    # the exception actually raised is MiscOperationError (via reraise).
    _ReraiseMiscOperationError(api_exceptions.HttpException(http_err))
  except (OperationError, OperationTimeoutError, app_exceptions.Error) as err:
    _ReraiseMiscOperationError(err)
def _ReraiseMiscOperationError(err):
  """Transform and re-raise error helper."""
  # exceptions.reraise substitutes the exception type while keeping the
  # original traceback; only err's message text is carried over.
  exceptions.reraise(MiscOperationError(six.text_type(err)))
class MiscOperationError(exceptions.Error):
  """Wrapper exception for errors treated as operation failures."""


class OperationError(exceptions.Error):
  """Raised when an operation completes with an error payload."""
  pass


class OperationTimeoutError(exceptions.Error):
  """Raised when polling an operation exceeds the retry budget."""
  pass


class Status(enum.Enum):
  """Lifecycle states reported for an operation (see GetStatus)."""
  PENDING = 1
  COMPLETED = 2
  ERROR = 3
class Operation(object):
  """Wrapper around Operation response objects for console output.

  Attributes:
    project: String, name of the project.
    id: String, ID of operation.
    start_time: String, time the operation started.
    status: Status enum, either PENDING, COMPLETED, or Error.
    op_resource: messages.Operation, the original Operation resource.
  """

  def __init__(self, op_response):
    """Creates the operation wrapper object."""
    # Parse the fully-qualified operation name against the
    # appengine.apps.operations collection to split out project and id.
    res = resources.REGISTRY.ParseRelativeName(op_response.name,
                                               'appengine.apps.operations')
    self.project = res.appsId
    self.id = res.Name()
    self.start_time = _GetInsertTime(op_response)
    self.status = GetStatus(op_response)
    self.op_resource = op_response

  def __eq__(self, other):
    # NOTE(review): defining __eq__ without __hash__ makes instances
    # unhashable on Python 3 — confirm no caller stores Operations in
    # sets or as dict keys.
    return (isinstance(other, Operation) and
            self.project == other.project and
            self.id == other.id and
            self.start_time == other.start_time and
            self.status == other.status and
            self.op_resource == other.op_resource)
def GetStatus(operation):
  """Returns string status for given operation.

  Args:
    operation: A messages.Operation instance.

  Returns:
    The status of the operation in string form.
  """
  if not operation.done:
    status = Status.PENDING
  elif operation.error:
    status = Status.ERROR
  else:
    status = Status.COMPLETED
  return status.name
def _GetInsertTime(operation):
"""Finds the insertTime property and return its string form.
Args:
operation: A messages.Operation instance.
Returns:
The time the operation started in string form or None if N/A.
"""
if not operation.metadata:
return None
properties = operation.metadata.additionalProperties
for prop in properties:
if prop.key == 'insertTime':
return prop.value.string_value
class AppEngineOperationPoller(waiter.OperationPoller):
  """A poller for appengine operations."""

  def __init__(self, operation_service, operation_metadata_type=None):
    """Sets up poller for appengine operations.

    Args:
      operation_service: apitools.base.py.base_api.BaseApiService, api service
        for retrieving information about ongoing operation.
      operation_metadata_type: Message class for the Operation metadata (for
        instance, OperationMetadataV1, or OperationMetadataV1Beta).
    """
    self.operation_service = operation_service
    self.operation_metadata_type = operation_metadata_type
    # Warnings already shown to the user, so each one is only logged once.
    self.warnings_seen = set()

  def IsDone(self, operation):
    """Overrides.

    Surfaces any new metadata warnings, then checks completion.

    Args:
      operation: the operation message to inspect.

    Returns:
      True if the operation is done, False otherwise.

    Raises:
      OperationError: if the completed operation carries an error payload.
    """
    self._LogNewWarnings(operation)
    if operation.done:
      log.debug('Operation [{0}] complete. Result: {1}'.format(
          operation.name,
          json.dumps(encoding.MessageToDict(operation), indent=4)))
      if operation.error:
        raise OperationError(requests.ExtractErrorMessage(
            encoding.MessageToPyValue(operation.error)))
      return True
    log.debug('Operation [{0}] not complete. Waiting to retry.'.format(
        operation.name))
    return False

  def Poll(self, operation_ref):
    """Overrides.

    Args:
      operation_ref: googlecloudsdk.core.resources.Resource.

    Returns:
      fetched operation message.
    """
    request_type = self.operation_service.GetRequestType('Get')
    request = request_type(name=operation_ref.RelativeName())
    operation = self.operation_service.Get(request)
    self._LogNewWarnings(operation)
    return operation

  def _LogNewWarnings(self, operation):
    """Logs metadata warnings not yet shown to the user (at most once each)."""
    if self.operation_metadata_type:
      # Log any new warnings to the end user.
      new_warnings = GetWarningsFromOperation(
          operation, self.operation_metadata_type) - self.warnings_seen
      for warning in new_warnings:
        log.warning(warning + '\n')
        self.warnings_seen.add(warning)

  def GetResult(self, operation):
    """Simply returns the operation.

    Args:
      operation: api_name_messages.Operation.

    Returns:
      the 'response' field of the Operation.
    """
    return operation
class AppEngineOperationBuildPoller(AppEngineOperationPoller):
  """Waits for a build to be present, or for the operation to finish."""

  def __init__(self, operation_service, operation_metadata_type):
    """Sets up poller for appengine operations.

    Args:
      operation_service: apitools.base.py.base_api.BaseApiService, api service
        for retrieving information about ongoing operation.
      operation_metadata_type: Message class for the Operation metadata (for
        instance, OperationMetadataV1, or OperationMetadataV1Beta).
    """
    # Unlike the base class, operation_metadata_type is required here since it
    # is needed to extract the build id from operation metadata.
    super(AppEngineOperationBuildPoller, self).__init__(operation_service,
                                                        operation_metadata_type)

  def IsDone(self, operation):
    """Done as soon as a build id appears in metadata, or the op completes."""
    if GetBuildFromOperation(operation, self.operation_metadata_type):
      return True
    return super(AppEngineOperationBuildPoller, self).IsDone(operation)
def GetMetadataFromOperation(operation, operation_metadata_type):
  """Decodes the operation's metadata into a typed message, or None.

  Args:
    operation: the operation message whose metadata to decode.
    operation_metadata_type: Message class to decode the metadata into.

  Returns:
    An operation_metadata_type instance, or None if there is no metadata.
  """
  metadata = operation.metadata
  if not metadata:
    return None
  # Round-trip through JSON to convert the untyped metadata message into the
  # caller-supplied metadata type.
  return encoding.JsonToMessage(operation_metadata_type,
                                encoding.MessageToJson(metadata))
def GetBuildFromOperation(operation, operation_metadata_type):
  """Returns the cloud build id from the operation metadata, if present."""
  metadata = GetMetadataFromOperation(operation, operation_metadata_type)
  if metadata and metadata.createVersionMetadata:
    return metadata.createVersionMetadata.cloudBuildId
  return None
def GetWarningsFromOperation(operation, operation_metadata_type):
  """Returns the set of warning strings attached to the operation metadata.

  Args:
    operation: the operation message whose metadata to inspect.
    operation_metadata_type: Message class for the Operation metadata.

  Returns:
    set of str, the warnings (empty when there is no metadata).
  """
  metadata = GetMetadataFromOperation(operation, operation_metadata_type)
  if not metadata:
    return set()
  # Construct the set directly from the repeated field; the previous
  # `set(w for w in metadata.warning)` generator was redundant (C401).
  return set(metadata.warning)
def WaitForOperation(operation_service, operation,
                     max_retries=None,
                     retry_interval=None,
                     operation_collection='appengine.apps.operations',
                     message=None,
                     poller=None):
  """Wait until the operation is complete or times out.

  Args:
    operation_service: The apitools service type for operations
    operation: The operation resource to wait on
    max_retries: Maximum number of times to poll the operation
    retry_interval: Frequency of polling in seconds
    operation_collection: The resource collection of the operation.
    message: str, the message to display while progress tracker displays.
    poller: AppEngineOperationPoller to poll with, defaulting to done.

  Returns:
    The operation resource when it has completed

  Raises:
    OperationError: if the operation contains an error.
    OperationTimeoutError: when the operation polling times out
  """
  poller = poller or AppEngineOperationPoller(operation_service)
  # Fast path: the operation passed in may already be complete.
  if poller.IsDone(operation):
    return poller.GetResult(operation)
  operation_ref = resources.REGISTRY.ParseRelativeName(
      operation.name,
      operation_collection)
  if max_retries is None:
    # NOTE(review): presumably -1 because the IsDone check above already
    # consumed one attempt of the hour-long budget — confirm.
    max_retries = DEFAULT_OPERATION_MAX_TRIES - 1
  if retry_interval is None:
    retry_interval = DEFAULT_OPERATION_RETRY_INTERVAL
  if message is None:
    message = 'Waiting for operation [{}] to complete'.format(
        operation_ref.RelativeName())
  # Convert to milliseconds
  retry_interval *= 1000
  try:
    completed_operation = waiter.WaitFor(
        poller,
        operation_ref,
        message,
        pre_start_sleep_ms=1000,
        max_retrials=max_retries,
        exponential_sleep_multiplier=1.0,
        sleep_ms=retry_interval)
  except waiter.TimeoutError:
    raise OperationTimeoutError(('Operation [{0}] timed out. This operation '
                                 'may still be underway.').format(
                                     operation.name))
  return completed_operation

View File

@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with region resources."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
class Region(object):
  """Value object describing an App Engine region and its capabilities."""

  def __init__(self, region, standard, flexible, search_api):
    # Region identifier, e.g. 'us-central'.
    self.region = region
    # Whether the standard environment is available in this region.
    self.standard = standard
    # Whether the flexible environment is available in this region.
    self.flexible = flexible
    # Whether the Search API is available in this region.
    self.search_api = search_api

  @classmethod
  def FromRegionResource(cls, region):
    """Create region from a google.cloud.location.Location message."""
    region_id = region.labels.additionalProperties[0].value
    # Collect the availability flags from the metadata key/value pairs.
    availability = {}
    for prop in region.metadata.additionalProperties:
      availability[prop.key] = bool(prop.value.boolean_value)
    return cls(region_id,
               availability.get('standardEnvironmentAvailable', False),
               availability.get('flexibleEnvironmentAvailable', False),
               availability.get('searchApiAvailable', False))

  def __str__(self):
    supported = []
    if self.standard:
      supported.append('standard')
    if self.flexible:
      supported.append('flexible')
    if self.search_api:
      supported.append('search_api')
    padded = '{region: <13}'.format(region=self.region)
    return padded + ' (supports {envs})'.format(envs=' and '.join(supported))

View File

@@ -0,0 +1,770 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Library code to support App Engine Flex runtime builders.
The App Engine Flex platform runs a user's application that has been packaged
into a docker image. At the lowest level, the user provides us with a source
directory complete with Dockerfile, which we build into an image and deploy.
To make development easier, Google provides blessed language runtimes that the
user can extend in their Dockerfile to get a working base image for their
application. To further make development easier, we do not require users to
author their own Dockerfiles for "canonical" applications for each of the
Silver Languages.
In order for this to be possible, preprocessing must be done prior to the
Docker build to inspect the user's source code and automatically generate a
Dockerfile.
Flex runtime builders are a per-runtime pipeline that covers the full journey
from source directory to docker image. They are stored as templated .yaml files
representing CloudBuild Build messages. These .yaml files contain a series of
CloudBuild build steps. Additionally, the runtime root stores a `runtimes.yaml`
file which contains a list of runtime names and mappings to the corresponding
builder yaml files.
Such a builder will look something like this (note that <angle_brackets> denote
values to be filled in by the builder author, and $DOLLAR_SIGNS denote a
literal part of the template to be substituted at runtime):
steps:
- name: 'gcr.io/google_appengine/python-builder:<version>'
env: ['GAE_APPLICATION_YAML_PATH=${_GAE_APPLICATION_YAML_PATH}']
- name: 'gcr.io/cloud-builders/docker:<docker_image_version>'
args: ['build', '-t', '$_OUTPUT_IMAGE', '.']
images: ['$_OUTPUT_IMAGE']
To test this out in the context of a real deployment, do something like the
following (ls/grep steps just for illustrating where files are):
$ ls /tmp/runtime-root
runtimes.yaml python-v1.yaml
$ cat /tmp/runtime-root/runtimes.yaml
schema_version: 1
runtimes:
python:
target:
file: python-v1.yaml
$ gcloud config set app/use_runtime_builders true
$ gcloud config set app/runtime_builders_root file:///tmp/runtime-root
$ cd $MY_APP_DIR
$ grep 'runtime' app.yaml
runtime: python
$ grep 'env' app.yaml
env: flex
$ gcloud beta app deploy
A (possibly) easier way of achieving the same thing if you don't have a
runtime_builders_root set up for development yet:
$ cd $MY_APP_DIR
$ export _OUTPUT_IMAGE=gcr.io/$PROJECT/appengine/placeholder
$ gcloud container builds submit \
--config=<(envsubst < /path/to/cloudbuild.yaml) .
$ gcloud app deploy --image-url=$_OUTPUT_IMAGE
Or (even easier) use a 'custom' runtime:
$ cd $MY_APP_DIR
$ ls
cloudbuild.yaml app.yaml
$ rm -f Dockerfile
$ grep 'runtime' app.yaml
runtime: custom
$ gcloud beta app deploy
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import contextlib
import enum
import os
import re
from googlecloudsdk.api_lib.cloudbuild import cloudbuild_util
from googlecloudsdk.api_lib.cloudbuild import config as cloudbuild_config
from googlecloudsdk.api_lib.storage import storage_api
from googlecloudsdk.api_lib.storage import storage_util
from googlecloudsdk.calliope import exceptions as calliope_exceptions
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import yaml
import six
import six.moves.urllib.error
import six.moves.urllib.parse
import six.moves.urllib.request
# "test-{ga,beta}" runtimes are canaries for unit testing
# Entries may be plain strings (exact match) or compiled regexes (matched via
# .match in RuntimeBuilderStrategy._IsAllowed).
_ALLOWLISTED_RUNTIMES_GA = frozenset(
    {'aspnetcore', 'php', 'nodejs', 'ruby', 'java',
     re.compile(r'(python|python-.+)$'),
     re.compile(r'(go|go1\..+)$'),
     # A runtime that is itself a 'gs://' URI pins the builder explicitly.
     re.compile('^gs://'),
     'test-ga', re.compile('test-re-[ab]')})
# Beta is a superset of GA plus its own canary runtime.
_ALLOWLISTED_RUNTIMES_BETA = frozenset(
    _ALLOWLISTED_RUNTIMES_GA |
    {'test-beta'})
class FileReadError(exceptions.Error):
  """Error indicating a file read operation failed (see _Read)."""
class ManifestError(exceptions.Error):
  """Error indicating a problem parsing or using the runtimes.yaml manifest."""
class ExperimentsError(exceptions.Error):
  """Error indicating a problem parsing or using the experiment config."""
class CloudBuildLoadError(exceptions.Error):
  """Error indicating an issue loading the runtime Cloud Build specification."""
class CloudBuildFileNotFound(CloudBuildLoadError):
  """Error indicating a missing Cloud Build file."""
class InvalidRuntimeBuilderURI(CloudBuildLoadError):
  """Error indicating that the runtime builder URI format wasn't recognized."""

  def __init__(self, uri):
    """Constructs the error with a message embedding the offending URI.

    Args:
      uri: str, the unrecognized runtime builder URI.
    """
    super(InvalidRuntimeBuilderURI, self).__init__(
        '[{}] is not a valid runtime builder URI. '
        'Please set the app/runtime_builders_root property to a URI with '
        'either the Google Cloud Storage (`gs://`) or local file (`file://`) '
        'protocol.'.format(uri))
class BuilderResolveError(exceptions.Error):
  """Error indicating that a build file could not be resolved."""
class RuntimeBuilderStrategy(enum.Enum):
  """Enum indicating when to use runtime builders."""
  NEVER = 1
  ALLOWLIST_BETA = 2  # That is, turned on for an allowed set of runtimes
  ALLOWLIST_GA = 3  # That is, turned on for an allowed set of runtimes
  ALWAYS = 4

  def _GetAllowlist(self):
    """Return the allowlist of runtimes for this strategy.

    The allowlist is kept as a constant within this module.

    Returns:
      frozenset, the runtime names/patterns allowed for this strategy.

    Raises:
      ValueError: if this strategy is not allowlist-based.
    """
    if self is RuntimeBuilderStrategy.ALLOWLIST_GA:
      return _ALLOWLISTED_RUNTIMES_GA
    if self is RuntimeBuilderStrategy.ALLOWLIST_BETA:
      return _ALLOWLISTED_RUNTIMES_BETA
    raise ValueError(
        'RuntimeBuilderStrategy {} is not an allowed strategy.'.format(self))

  def _IsAllowed(self, runtime):
    """Returns True if runtime matches an allowlist entry (str or regex)."""
    for candidate in self._GetAllowlist():
      # Compiled regex entries expose .match; plain strings compare exactly.
      matcher = getattr(candidate, 'match', None)
      if matcher is not None:
        if matcher(runtime):
          return True
      elif candidate == runtime:
        return True
    return False

  def ShouldUseRuntimeBuilders(self, runtime, needs_dockerfile):
    """Returns True if runtime should use runtime builders under this strategy.

    For the most part, this is obvious: the ALWAYS strategy returns True, the
    ALLOWLIST_${TRACK} strategies return True if the given runtime is in the
    list of _ALLOWLISTED_RUNTIMES_${TRACK}, and the NEVER strategy returns
    False.

    However, 'custom' runtimes are special: under every strategy except NEVER,
    we return True only if there is no `Dockerfile` in the current directory
    (this method assumes that there is *either* a `Dockerfile` or a
    `cloudbuild.yaml` file), since one needs to get generated by Cloud Build.

    Args:
      runtime: str, the runtime being built.
      needs_dockerfile: bool, whether the Dockerfile in the source directory is
        absent.

    Returns:
      bool, whether to use the runtime builders.

    Raises:
      ValueError: if an unrecognized runtime_builder_strategy is given.
    """
    custom_capable = (RuntimeBuilderStrategy.ALWAYS,
                      RuntimeBuilderStrategy.ALLOWLIST_BETA,
                      RuntimeBuilderStrategy.ALLOWLIST_GA)
    if runtime == 'custom' and self in custom_capable:
      return needs_dockerfile
    if self is RuntimeBuilderStrategy.NEVER:
      return False
    if self is RuntimeBuilderStrategy.ALWAYS:
      return True
    if self in (RuntimeBuilderStrategy.ALLOWLIST_BETA,
                RuntimeBuilderStrategy.ALLOWLIST_GA):
      return self._IsAllowed(runtime)
    raise ValueError('Invalid runtime builder strategy [{}].'.format(self))
def _Join(*args):
"""Join parts of a gs:// Cloud Storage or local file:// path."""
# URIs always uses '/' as separator, regardless of local platform.
return '/'.join([arg.strip('/') for arg in args])
@contextlib.contextmanager
def _Read(uri):
  """Read a file/object (local file:// or gs:// Cloud Storage path).

  >>> with _Read('gs://builder/object.txt') as f:
  ...     assert f.read() == 'foo'
  >>> with _Read('file:///path/to/object.txt') as f:
  ...     assert f.read() == 'bar'

  Args:
    uri: str, the path to the file/object to read. Must begin with 'file://' or
      'gs://'

  Yields:
    a file-like context manager.

  Raises:
    FileReadError: If opening or reading the file failed.
    InvalidRuntimeBuilderURI: If the URI is invalid (doesn't begin with an
      appropriate prefix).
  """
  try:
    if uri.startswith('file://'):
      # urlopen handles file:// URIs; closing() turns the response into a
      # context manager so the handle is released on exit.
      with contextlib.closing(six.moves.urllib.request.urlopen(uri)) as req:
        yield req
    elif uri.startswith('gs://'):
      storage_client = storage_api.StorageClient()
      object_ = storage_util.ObjectReference.FromUrl(uri)
      with contextlib.closing(storage_client.ReadObject(object_)) as f:
        yield f
    else:
      raise InvalidRuntimeBuilderURI(uri)
  except (six.moves.urllib.error.HTTPError, six.moves.urllib.error.URLError,
          calliope_exceptions.BadFileException) as e:
    # Keep the full stack trace at debug level only; surface a single simple
    # error type that callers can catch.
    log.debug('', exc_info=True)
    raise FileReadError(six.text_type(e))
class BuilderReference(object):
  """A reference to a specific cloudbuild.yaml file to use."""

  def __init__(self, runtime, build_file_uri, deprecation_message=None):
    """Constructs a BuilderReference.

    Args:
      runtime: str, The runtime this builder corresponds to.
      build_file_uri: str, The full URI of the build configuration or None if
        this runtime existed but no longer can be built (deprecated).
      deprecation_message: str, A message to print when using this builder or
        None if not deprecated.
    """
    self.runtime = runtime
    self.build_file_uri = build_file_uri
    self.deprecation_message = deprecation_message

  def LoadCloudBuild(self, params):
    """Loads the Cloud Build configuration file for this builder reference.

    Args:
      params: dict, a dictionary of values to be substituted in to the
        Cloud Build configuration template corresponding to this runtime
        version.

    Returns:
      Build message, the parsed and parameterized Cloud Build configuration
      file.

    Raises:
      CloudBuildLoadError: If this reference has no associated build file
        (deprecated runtime).
      FileReadError: If reading the configuration file fails.
      InvalidRuntimeBuilderURI: If the URI of the configuration file is
        invalid.
    """
    if not self.build_file_uri:
      raise CloudBuildLoadError(
          'There is no build file associated with runtime [{runtime}]'
          .format(runtime=self.runtime))
    messages = cloudbuild_util.GetMessagesModule()
    with _Read(self.build_file_uri) as data:
      build = cloudbuild_config.LoadCloudbuildConfigFromStream(
          data, messages=messages, params=params)
    # Allow substitutions that are not referenced by every build step.
    if build.options is None:
      build.options = messages.BuildOptions()
    build.options.substitutionOption = (
        build.options.SubstitutionOptionValueValuesEnum.ALLOW_LOOSE)
    for step in build.steps:
      has_yaml_path = False
      has_runtime_version = False
      for env in step.env:
        parts = env.split('=')
        log.debug('Env var in build step: ' + str(parts))
        if 'GAE_APPLICATION_YAML_PATH' in parts:
          has_yaml_path = True
        if 'GOOGLE_RUNTIME_VERSION' in parts:
          has_runtime_version = True
      # Ensure every step sees the standard env vars without clobbering
      # values the builder author set explicitly.
      if not has_yaml_path:
        step.env.append(
            'GAE_APPLICATION_YAML_PATH=${_GAE_APPLICATION_YAML_PATH}')
      if not has_runtime_version and '_GOOGLE_RUNTIME_VERSION' in params:
        step.env.append('GOOGLE_RUNTIME_VERSION=${_GOOGLE_RUNTIME_VERSION}')
    return build

  def WarnIfDeprecated(self):
    """Warns that this runtime is deprecated (if it has been marked as such)."""
    if self.deprecation_message:
      log.warning(self.deprecation_message)

  def __eq__(self, other):
    # BUG FIX: comparing against an unrelated type used to raise
    # AttributeError; returning NotImplemented lets Python fall back to its
    # default comparison handling.
    if not isinstance(other, BuilderReference):
      return NotImplemented
    return (self.runtime == other.runtime and
            self.build_file_uri == other.build_file_uri and
            self.deprecation_message == other.deprecation_message)

  def __ne__(self, other):
    result = self.__eq__(other)
    if result is NotImplemented:
      return result
    return not result
class Manifest(object):
  """Loads and parses a runtimes.yaml manifest.

  To resolve a builder configuration file to use, a given runtime name is
  looked up in this manifest. For each runtime, it either points to a
  configuration file directly, or to another runtime. If it points to a runtime,
  resolution continues until a configuration file is reached.

  The following is the proto-ish spec for the yaml schema of the manifest:

  # Used to determine if this client can parse this manifest. If the number is
  # less than or equal to the version this client knows about, it is compatible.
  int schema_version; # Required

  # The registry of all the runtimes that this manifest defines. The key of the
  # map is the runtime name that appears in app.yaml.
  <string, Runtime> runtimes {

    # Determines which builder this runtime points to.
    Target target {

      oneof {
        # A path relative to the manifest's location of the builder spec to use.
        string file;

        # Another runtime registered in this file that should be resolved and
        # used for this runtime.
        string runtime;
      }
    }

    # Specifies deprecation information about this runtime.
    Deprecation deprecation {

      # A message to be displayed to the user on use of this runtime.
      string message;
    }
  }
  """
  # Highest manifest schema version this client can parse.
  SCHEMA_VERSION = 1

  @classmethod
  def LoadFromURI(cls, uri):
    """Loads a manifest from a gs:// or file:// path.

    Args:
      uri: str, A gs:// or file:// URI

    Returns:
      Manifest, the loaded manifest.
    """
    log.debug('Loading runtimes manifest from [%s]', uri)
    with _Read(uri) as f:
      data = yaml.load(f, file_hint=uri)
    return cls(uri, data)

  def __init__(self, uri, data):
    """Use LoadFromURI, not this constructor directly."""
    self._uri = uri
    self._data = data
    required_version = self._data.get('schema_version', None)
    if required_version is None:
      raise ManifestError(
          'Unable to parse the runtimes manifest: [{}]'.format(uri))
    if required_version > Manifest.SCHEMA_VERSION:
      raise ManifestError(
          'Unable to parse the runtimes manifest. Your client supports schema '
          'version [{supported}] but requires [{required}]. Please update your '
          'SDK to a later version.'.format(supported=Manifest.SCHEMA_VERSION,
                                          required=required_version))

  def Runtimes(self):
    """Get all registered runtimes in the manifest.

    Returns:
      [str], The runtime names.
    """
    return list(self._data.get('runtimes', {}).keys())

  def GetBuilderReference(self, runtime):
    """Gets the associated reference for the given runtime.

    Args:
      runtime: str, The name of the runtime.

    Returns:
      BuilderReference, The reference pointed to by the manifest, or None if the
      runtime is not registered.

    Raises:
      ManifestError: if a problem occurred parsing the manifest.
    """
    runtimes = self._data.get('runtimes', {})
    current_runtime = runtime
    # Track every runtime visited while following aliases so cycles are
    # detected instead of looping forever.
    seen = {current_runtime}
    while True:
      runtime_def = runtimes.get(current_runtime, None)
      if not runtime_def:
        log.debug('Runtime [%s] not found in manifest [%s]',
                  current_runtime, self._uri)
        return None
      new_runtime = runtime_def.get('target', {}).get('runtime', None)
      if new_runtime:
        # Runtime is an alias for another runtime, resolve the alias.
        log.debug('Runtime [%s] is an alias for [%s]',
                  current_runtime, new_runtime)
        if new_runtime in seen:
          raise ManifestError(
              'A circular dependency was found while resolving the builder for '
              'runtime [{runtime}]'.format(runtime=runtime))
        seen.add(new_runtime)
        current_runtime = new_runtime
        continue
      deprecation_msg = runtime_def.get('deprecation', {}).get('message', None)
      build_file = runtime_def.get('target', {}).get('file', None)
      if build_file:
        # This points to a build configuration file, create the reference.
        # Build file paths are relative to the manifest's own location.
        full_build_uri = _Join(os.path.dirname(self._uri), build_file)
        log.debug('Resolved runtime [%s] as build configuration [%s]',
                  current_runtime, full_build_uri)
        return BuilderReference(
            current_runtime, full_build_uri, deprecation_msg)
      # There is no alias or build file. This means the runtime exists, but
      # cannot be used. There might still be a deprecation message we can show
      # to the user.
      log.debug('Resolved runtime [%s] has no build configuration',
                current_runtime)
      return BuilderReference(current_runtime, None, deprecation_msg)
class Experiments(object):
  """Runtime experiment configs as read from a gs:// or a file:// source.

  The experiment config file follows the following protoish schema:

  # Used to determine if this client can parse this manifest. If the number is
  # less than or equal to the version this client knows about, it is compatible.
  int schema_version; # Required

  # Map of experiments and their rollout percentage.
  # The key is the name of the experiment, the value is an integer between 0
  # and 100 representing the rollout percentage
  # In case no experiments are defined, an empty 'experiments:' section needs to
  # be present.
  <String, Number> experiments
  """
  # Highest experiment-config schema version this client can parse.
  SCHEMA_VERSION = 1
  # Name of the config file expected inside the runtime builders root.
  CONFIG_FILE = 'experiments.yaml'
  # Well-known experiment name.
  TRIGGER_BUILD_SERVER_SIDE = 'trigger_build_server_side'

  @classmethod
  def LoadFromURI(cls, dir_uri):
    """Loads a runtime experiment config from a gs:// or file:// path.

    Args:
      dir_uri: str, A gs:// or file:// URI pointing to a folder that contains
        the file called Experiments.CONFIG_FILE

    Returns:
      Experiments, the loaded runtime experiments config.

    Raises:
      ExperimentsError: If the config could not be read or parsed.
    """
    uri = _Join(dir_uri, cls.CONFIG_FILE)
    log.debug('Loading runtimes experiment config from [%s]', uri)
    try:
      with _Read(uri) as f:
        data = yaml.load(f, file_hint=uri)
      return cls(uri, data)
    except (FileReadError, yaml.YAMLParseError) as e:
      # Both failure modes produced identical messages; handle them together.
      raise ExperimentsError(
          'Unable to read the runtimes experiment config: [{}], error: {}'
          .format(uri, e))

  def __init__(self, uri, data):
    """Use LoadFromURI, not this constructor directly."""
    self._uri = uri
    self._data = data
    required_version = self._data.get('schema_version', None)
    if required_version is None:
      raise ExperimentsError(
          'Unable to parse the runtimes experiment config due to missing '
          'schema_version field: [{}]'.format(uri))
    if required_version > Experiments.SCHEMA_VERSION:
      # BUG FIX: the message previously interpolated Manifest.SCHEMA_VERSION;
      # this class's own schema version is the supported one.
      raise ExperimentsError(
          'Unable to parse the runtimes experiments config. Your client '
          'supports schema version [{supported}] but requires [{required}]. '
          'Please update your SDK to a newer version.'.format(
              supported=Experiments.SCHEMA_VERSION,
              required=required_version))

  def Experiments(self):
    """Get all experiments and their rollout percentage.

    Returns:
      dict[str,int] Experiments and their rollout state, or None when the
      'experiments:' section is present but empty.
    """
    return self._data.get('experiments')

  def GetExperimentPercentWithDefault(self, experiment, default=0):
    """Get the rollout percentage of an experiment or return 'default'.

    Args:
      experiment: the name of the experiment
      default: the value to return if the experiment was not found

    Returns:
      int, the rollout percent of the experiment.
    """
    # BUG FIX: an empty 'experiments:' section (which the schema explicitly
    # allows) parses to None; previously that raised TypeError on
    # subscription instead of returning the default.
    experiments = self._data.get('experiments') or {}
    return experiments.get(experiment, default)
class Resolver(object):
  """Resolves the location of a builder configuration for a runtime.

  There are several possible locations that builder configuration can be found
  for a given runtime, and they are checked in order. Check GetBuilderReference
  for the locations checked.
  """
  # The name of the manifest in the builders root that registers the runtimes.
  MANIFEST_NAME = 'runtimes.yaml'
  # Alternate manifest used when building with buildpacks.
  BUILDPACKS_MANIFEST_NAME = 'runtimes_buildpacks.yaml'
  # The name of the file in your local source for when you are using custom.
  CLOUDBUILD_FILE = 'cloudbuild.yaml'

  def __init__(self, runtime, source_dir, legacy_runtime_version,
               use_flex_with_buildpacks=False):
    """Instantiates a resolver.

    Args:
      runtime: str, The name of the runtime to be resolved.
      source_dir: str, The local path of the source code being deployed.
      legacy_runtime_version: str, The value from runtime_config.runtime_version
        in app.yaml. This is only used in legacy mode.
      use_flex_with_buildpacks: bool, if true, use the build-image and
        run-image built through buildpacks.

    Returns:
      Resolver, The instantiated resolver.
    """
    self.runtime = runtime
    self.source_dir = os.path.abspath(source_dir)
    self.legacy_runtime_version = legacy_runtime_version
    # The app/runtime_builders_root property is required; Get(required=True)
    # raises if it is unset.
    self.build_file_root = properties.VALUES.app.runtime_builders_root.Get(
        required=True)
    self.use_flex_with_buildpacks = use_flex_with_buildpacks
    log.debug('Using use_flex_with_buildpacks [%s]',
              self.use_flex_with_buildpacks)
    log.debug('Using runtime builder root [%s]', self.build_file_root)

  def GetBuilderReference(self):
    """Resolve the builder reference.

    Returns:
      BuilderReference, The reference to the builder configuration.

    Raises:
      BuilderResolveError: if this fails to resolve a builder.
    """
    # Try builder resolution in the following order, stopping once one is found.
    builder_def = (
        self._GetReferenceCustom() or
        self._GetReferencePinned() or
        self._GetReferenceFromManifest() or
        self._GetReferenceFromLegacy()
    )
    if not builder_def:
      raise BuilderResolveError(
          'Unable to resolve a builder for runtime: [{runtime}]'
          .format(runtime=self.runtime))
    return builder_def

  def _GetReferenceCustom(self):
    """Tries to resolve the reference for runtime: custom.

    If the user has an app.yaml with runtime: custom we will look in the root
    of their source directory for a custom build pipeline named cloudbuild.yaml.

    This should only be called if there is *not* a Dockerfile in the source
    root since that means they just want to build and deploy that Docker image.

    Returns:
      BuilderReference or None
    """
    if self.runtime == 'custom':
      log.debug('Using local cloud build file [%s] for custom runtime.',
                Resolver.CLOUDBUILD_FILE)
      # Turn the local path into a file:// URI (Windows backslashes become
      # forward slashes).
      return BuilderReference(
          self.runtime,
          _Join('file:///' + self.source_dir.replace('\\', '/').strip('/'),
                Resolver.CLOUDBUILD_FILE))
    return None

  def _GetReferencePinned(self):
    """Tries to resolve the reference for when a runtime is pinned.

    Usually a runtime is looked up in the manifest and resolved to a
    configuration file. The user does have the option of 'pinning' their build
    to a specific configuration by specifying the absolute path to a builder
    in the runtime field.

    Returns:
      BuilderReference or None
    """
    if self.runtime.startswith('gs://'):
      log.debug('Using pinned cloud build file [%s].', self.runtime)
      return BuilderReference(self.runtime, self.runtime)
    return None

  def _GetReferenceFromManifest(self):
    """Tries to resolve the reference by looking up the runtime in the manifest.

    Calculate the location of the manifest based on the builder root and load
    that data. Then try to resolve a reference based on the contents of the
    manifest.

    Returns:
      BuilderReference or None
    """
    manifest_file_name = (
        Resolver.BUILDPACKS_MANIFEST_NAME
        if self.use_flex_with_buildpacks
        else Resolver.MANIFEST_NAME)
    manifest_uri = _Join(self.build_file_root, manifest_file_name)
    log.debug('Using manifest_uri [%s]', manifest_uri)
    try:
      manifest = Manifest.LoadFromURI(manifest_uri)
      return manifest.GetBuilderReference(self.runtime)
    except FileReadError:
      # A missing/unreadable manifest is not fatal here; fall through to
      # legacy resolution.
      log.debug('', exc_info=True)
      return None

  def _GetReferenceFromLegacy(self):
    """Tries to resolve the reference by the legacy resolution process.

    TODO(b/37542861): This can be removed after all runtimes have been migrated
    to publish their builders in the manifest instead of <runtime>.version
    files.

    If the runtime is not found in the manifest, use legacy resolution. If the
    app.yaml contains a runtime_config.runtime_version, this loads the file from
    '<runtime>-<version>.yaml' in the runtime builders root. Otherwise, it
    checks '<runtime>.version' to get the default version, and loads the
    configuration for that version.

    Returns:
      BuilderReference or None
    """
    if self.legacy_runtime_version:
      # We already have a pinned version specified, just use that file.
      return self._GetReferenceFromLegacyWithVersion(
          self.legacy_runtime_version)

    log.debug('Fetching version for runtime [%s] in legacy mode', self.runtime)
    version_file_name = self.runtime + '.version'
    version_file_uri = _Join(self.build_file_root, version_file_name)
    try:
      with _Read(version_file_uri) as f:
        version = f.read().decode().strip()
    except FileReadError:
      log.debug('', exc_info=True)
      return None

    # Now that we resolved the default version, use that for the file.
    log.debug('Using version [%s] for runtime [%s] in legacy mode',
              version, self.runtime)
    return self._GetReferenceFromLegacyWithVersion(version)

  def _GetReferenceFromLegacyWithVersion(self, version):
    """Gets the name of configuration file to use for legacy mode.

    Args:
      version: str, The pinned version of the configuration file.

    Returns:
      BuilderReference
    """
    file_name = '-'.join([self.runtime, version]) + '.yaml'
    file_uri = _Join(self.build_file_root, file_name)
    log.debug('Calculated builder definition using legacy version [%s]',
              file_uri)
    return BuilderReference(self.runtime, file_uri)
def FromServiceInfo(service, source_dir, use_flex_with_buildpacks=False):
  """Constructs a BuilderReference from a ServiceYamlInfo.

  Args:
    service: ServiceYamlInfo, The parsed service config.
    source_dir: str, the source containing the application directory to build.
    use_flex_with_buildpacks: bool, if true, use the build-image and
      run-image built through buildpacks.

  Returns:
    BuilderReference for the service.
  """
  runtime_config = service.parsed.runtime_config
  # runtime_version is only consulted by the legacy resolution path.
  legacy_version = (runtime_config.get('runtime_version', None)
                    if runtime_config else None)
  resolver = Resolver(service.runtime, source_dir, legacy_version,
                      use_flex_with_buildpacks)
  return resolver.GetBuilderReference()

View File

@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*- #
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a registry for storing per-runtime information.
A registry is essentially a wrapper around a Python dict that stores a mapping
from (runtime, environment) to arbitrary data. Its main feature is that it
supports lookups by matching both the runtime and the environment.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from six.moves import map # pylint:disable=redefined-builtin
class RegistryEntry(object):
  """An entry in the Registry.

  Attributes:
    runtime: str or re.RegexObject, the runtime to be staged
    envs: set(env.Environment), the environments to be staged
  """

  def __init__(self, runtime, envs):
    self.runtime = runtime
    self.envs = envs

  def _RuntimeMatches(self, runtime):
    # Regex entries expose .match; plain strings fall back to equality.
    matcher = getattr(self.runtime, 'match', None)
    if matcher is None:
      return self.runtime == runtime
    return matcher(runtime)

  def _EnvMatches(self, env):
    return env in self.envs

  def Matches(self, runtime, env):
    """Returns True iff the given runtime and environment match this entry.

    The runtime matches if it is an exact string match or regex match; the
    environment matches if it is one of this entry's environments.

    Args:
      runtime: str, the runtime to match
      env: env.Environment, the environment to match

    Returns:
      bool, whether the given runtime and environment match.
    """
    return self._RuntimeMatches(runtime) and self._EnvMatches(env)

  def __hash__(self):
    # Sets are unhashable and Environments are unorderable, so fold the
    # member hashes into an order-independent sum.
    env_hash = sum(sorted(hash(env) for env in self.envs))
    return hash((self.runtime, env_hash))

  def __eq__(self, other):
    return self.runtime == other.runtime and self.envs == other.envs

  def __ne__(self, other):
    return not self.__eq__(other)
class Registry(object):
  """A registry to store values for various runtimes and environments.

  The registry is a map from (runtime, app-engine-environment) to
  user-specified values. As an example, storing Booleans for different
  runtimes/environments would look like:

  REGISTRY = {
      RegistryEntry('php72', {env.STANDARD}): True,
      RegistryEntry('php55', {env.STANDARD}): False,
      RegistryEntry('nodejs8', {env.FLEX}): False,
  }

  Attributes:
    mappings: dict, where keys are RegistryEntry objects and values can be
      of any type
    override: object or None; if specified, this value will always be returned
      by Get()
    default: object or None; if specified, will be returned if Get() could not
      find a matching registry entry
  """

  def __init__(self, mappings=None, override=None, default=None):
    self.mappings = mappings or {}
    self.override = override
    self.default = default

  def Get(self, runtime, env):
    """Return the associated value for the given runtime/environment.

    Args:
      runtime: str, the runtime to get a stager for
      env: env, the environment to get a stager for

    Returns:
      object, the matching entry, or override if one was specified. If no
      match is found, will return default if specified or None otherwise.
    """
    # NOTE(review): a falsy override (e.g. 0, '' or False) is ignored by this
    # truthiness test — confirm that is intended before relying on it.
    if self.override:
      return self.override
    matched = next(
        (value for entry, value in self.mappings.items()
         if entry.Matches(runtime, env)),
        None)
    if matched is not None:
      return matched
    # default is None when unset, so this also covers the "no default" case.
    return self.default

View File

@@ -0,0 +1,235 @@
# -*- coding: utf-8 -*- #
# Copyright 2015 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Package containing fingerprinting for all runtimes.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from gae_ext_runtime import ext_runtime
from googlecloudsdk.api_lib.app import ext_runtime_adapter
from googlecloudsdk.api_lib.app.runtimes import python
from googlecloudsdk.api_lib.app.runtimes import python_compat
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import log
# Registry of runtime fingerprinters, consulted in order by
# IdentifyDirectory(). Each entry exposes NAME, ALLOWED_RUNTIME_NAMES and
# Fingerprint(path, params); entries are either lazily loaded
# CoreRuntimeLoader wrappers or a directly imported module (python_compat).
RUNTIMES = [
    # Note that ordering of runtimes here is very important and changes to the
    # relative positions need to be tested carefully.

    # Custom comes first, if we've got a Dockerfile this is a custom runtime.
    ext_runtime_adapter.CoreRuntimeLoader('custom', 'Custom',
                                          ['custom']),

    # Go's position is relatively flexible due to its orthogonal nature.
    ext_runtime_adapter.CoreRuntimeLoader('go', 'Go', ['go', 'custom']),
    ext_runtime_adapter.CoreRuntimeLoader('ruby', 'Ruby', ['ruby', 'custom']),
    ext_runtime_adapter.CoreRuntimeLoader('nodejs', 'Node.js',
                                          ['nodejs', 'custom']),
    ext_runtime_adapter.CoreRuntimeLoader('java', 'Java',
                                          ['java', 'java7', 'custom']),
    python_compat,

    # Python and PHP are last because they match if any .py or .php file is
    # present.
    ext_runtime_adapter.CoreRuntimeLoader('python', 'Python',
                                          ['python', 'custom']),
    ext_runtime_adapter.CoreRuntimeLoader('php', 'PHP', ['php', 'custom']),
]
class UnidentifiedDirectoryError(exceptions.Error):
  """Raised when GenerateConfigs() can't identify the directory."""

  def __init__(self, path):
    """Constructor.

    Args:
      path: (basestring) Directory we failed to identify.
    """
    self.path = path
    super(UnidentifiedDirectoryError, self).__init__(
        'Unrecognized directory type: [{0}]'.format(self.path))
class ExtRuntimeError(exceptions.Error):
  """Wrapper for errors raised by the ext_runtime library.

  ext_runtime.Error exceptions caught at this module's boundaries
  (IdentifyDirectory, GenerateConfigs, GenerateConfigData) are re-raised as
  this type so callers only handle googlecloudsdk exceptions.
  """
class ConflictingConfigError(exceptions.Error):
  """Property in app.yaml conflicts with params passed to fingerprinter.

  Raised by _GetModule() when an existing app.yaml is incompatible with the
  requested operation (missing --custom, non-Flexible config, or a runtime
  mismatch with the --runtime flag).
  """
class AlterConfigFileError(exceptions.Error):
  """Error when attempting to update an existing config file (app.yaml)."""

  def __init__(self, inner_exception):
    """Constructor.

    Args:
      inner_exception: the underlying error that prevented the app.yaml
        update; interpolated into the user-facing message.
    """
    message = ('Could not alter app.yaml due to an internal error:\n{0}\n'
               'Please update app.yaml manually.'.format(inner_exception))
    super(AlterConfigFileError, self).__init__(message)
def IdentifyDirectory(path, params=None):
  """Try to identify the given directory.

  As a side-effect, if there is a config file in 'params' with a runtime of
  'custom', this sets params.custom to True.

  Args:
    path: (basestring) Root directory to identify.
    params: (ext_runtime.Params or None) Parameters passed through to the
      fingerprinters. Uses defaults if not provided.

  Raises:
    ExtRuntimeError: if a fingerprinter raised ext_runtime.Error.

  Returns:
    (ext_runtime.Configurator or None) Returns a module if we've identified
    it, None if not.
  """
  if not params:
    params = ext_runtime.Params()

  # Parameter runtime has precedence over the runtime declared in app.yaml.
  if params.runtime:
    specified_runtime = params.runtime
  elif params.appinfo:
    specified_runtime = params.appinfo.GetEffectiveRuntime()
  else:
    specified_runtime = None

  if specified_runtime == 'custom':
    params.custom = True

  for runtime in RUNTIMES:
    # If we have an app.yaml, don't fingerprint for any runtimes that don't
    # allow the runtime name it specifies.
    if (specified_runtime and runtime.ALLOWED_RUNTIME_NAMES and
        specified_runtime not in runtime.ALLOWED_RUNTIME_NAMES):
      log.info('Not checking for [%s] because runtime is [%s]' %
               (runtime.NAME, specified_runtime))
      continue
    try:
      configurator = runtime.Fingerprint(path, params)
    except ext_runtime.Error as ex:
      # Bug fix: `ex.message` does not exist on Python 3 exceptions and would
      # raise AttributeError instead of the intended error; str(ex) is the
      # portable way to get the message text.
      raise ExtRuntimeError(str(ex))
    if configurator:
      return configurator
  return None
def _GetModule(path, params=None, config_filename=None):
  """Helper function for generating configs.

  Args:
    path: (basestring) Root directory to identify.
    params: (ext_runtime.Params or None) Parameters passed through to the
      fingerprinters. Uses defaults if not provided.
    config_filename: (str or None) Filename of the config file (app.yaml).

  Raises:
    UnidentifiedDirectoryError: No runtime module matched the directory.
    ConflictingConfigError: Current app.yaml conflicts with other params.

  Returns:
    ext_runtime.Configurator, the configurator for the path
  """
  if not params:
    params = ext_runtime.Params()

  config = params.appinfo

  # An app.yaml exists, results in a lot more cases
  if config and not params.deploy:
    # Enforce --custom
    if not params.custom:
      raise ConflictingConfigError(
          'Configuration file already exists. This command generates an '
          'app.yaml configured to run an application on Google App Engine. '
          'To create the configuration files needed to run this '
          'application with docker, try `gcloud preview app gen-config '
          '--custom`.')
    # Check that current config is for MVM
    if not config.IsVm():
      raise ConflictingConfigError(
          'gen-config is only supported for App Engine Flexible. Please '
          'use "vm: true" in your app.yaml if you would like to use App Engine '
          'Flexible to run your application.')
    # Check for conflicting --runtime and runtime in app.yaml
    if (config.GetEffectiveRuntime() != 'custom' and params.runtime is not None
        and params.runtime != config.GetEffectiveRuntime()):
      # Bug fix: corrected "conficts" -> "conflicts" in the user-facing
      # error message.
      raise ConflictingConfigError(
          '[{0}] contains "runtime: {1}" which conflicts with '
          '--runtime={2}.'.format(config_filename, config.GetEffectiveRuntime(),
                                  params.runtime))

  module = IdentifyDirectory(path, params)
  if not module:
    raise UnidentifiedDirectoryError(path)

  return module
def GenerateConfigs(path, params=None, config_filename=None):
  """Identify runtime and generate config files for a directory.

  If a runtime can be identified for the given directory, calls the runtime's
  GenerateConfigs method, which writes configs to the directory.

  Args:
    path: (basestring) Root directory to identify.
    params: (ext_runtime.Params or None) Parameters passed through to the
      fingerprinters. Uses defaults if not provided.
    config_filename: (str or None) Filename of the config file (app.yaml).

  Raises:
    ExtRuntimeError: if there was an error generating configs

  Returns:
    (bool): True if files were written
  """
  module = _GetModule(path, params=params, config_filename=config_filename)
  try:
    return module.GenerateConfigs()
  except ext_runtime.Error as ex:
    # Bug fix: exceptions have no `.message` attribute on Python 3; use
    # str(ex) to extract the message portably.
    raise ExtRuntimeError(str(ex))
def GenerateConfigData(path, params=None, config_filename=None):
  """Identify runtime and generate contents of config files for a directory.

  If a runtime can be identified for the given directory, calls the runtime's
  GenerateConfigData method, which generates the contents of config files.

  Args:
    path: (basestring) Root directory to identify.
    params: (ext_runtime.Params or None) Parameters passed through to the
      fingerprinters. Uses defaults if not provided.
    config_filename: (str or None) Filename of the config file (app.yaml).

  Raises:
    ExtRuntimeError: if there was an error generating configs

  Returns:
    [ext_runtime.GeneratedFile] generated config files.
  """
  module = _GetModule(path, params=params, config_filename=config_filename)
  try:
    return module.GenerateConfigData()
  except ext_runtime.Error as ex:
    # Bug fix: exceptions have no `.message` attribute on Python 3; use
    # str(ex) to extract the message portably.
    raise ExtRuntimeError(str(ex))

Some files were not shown because too many files have changed in this diff Show More