feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,776 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility file that contains helpers for the Cloud TPU Execution groups."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import datetime
import os
import re
import sys
import time
from apitools.base.py import list_pager
from apitools.base.py.exceptions import HttpNotFoundError
from googlecloudsdk.api_lib.compute import base_classes
from googlecloudsdk.api_lib.compute.operations import poller
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.compute import scope as compute_scope
from googlecloudsdk.command_lib.compute import ssh_utils
from googlecloudsdk.command_lib.compute.instances import flags as instance_flags
from googlecloudsdk.command_lib.projects import util as p_util
from googlecloudsdk.command_lib.util.ssh import ssh
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core.util import retry
from googlecloudsdk.core.util import times
import six
class DefaultArgs(object):
  """Fills in defaults for required flags the user did not set explicitly."""
  @staticmethod
  def ValidateName(args):
    """Defaults args.name to the local part of the configured account."""
    # Get(required=True) raises if no account has been configured.
    account = properties.VALUES.core.account.Get(required=True)
    # Local part of "user@domain"; the whole string when there is no '@'.
    username = account.partition('@')[0]
    if not args.name:
      args.name = username
  @staticmethod
  def ValidateZone(args):
    """Defaults args.zone to the configured compute/zone property."""
    if not args.zone:
      args.zone = properties.VALUES.compute.zone.Get(required=True)
class TPUNode(object):
  """Helper to create and modify TPU nodes.

  Thin wrapper over the Cloud TPU API (v1 or v1alpha1, depending on the
  release track) exposing node create/delete/get/list, long-running
  operation polling, and TensorFlow version selection.
  """
  def __init__(self, release_track):
    # Only the ALPHA release track talks to the alpha API surface; all
    # other tracks use v1.
    if release_track == base.ReleaseTrack.ALPHA:
      self._api_version = 'v1alpha1'
    else:
      self._api_version = 'v1'
    self.client = apis.GetClientInstance('tpu', self._api_version)
    self.messages = apis.GetMessagesModule('tpu', self._api_version)
  def _CreateDefaultNode(
      self, accelerator_type, tf_version, preemptible, network):
    """Builds the Node message sent in the create request."""
    node = self.messages.Node()
    node.acceleratorType = accelerator_type
    node.network = network
    node.tensorflowVersion = tf_version
    node.schedulingConfig = self.messages.SchedulingConfig(
        preemptible=preemptible)
    return node
  def _GetTpuOperationRef(self, operation):
    """Get a resource reference to a long running operation."""
    return resources.REGISTRY.ParseRelativeName(
        operation.name, collection='tpu.projects.locations.operations')
  def Create(
      self, name, accelerator_type, tf_version, zone, preemptible, network):
    """Create builds and issues a request to create a TPU node.

    Args:
      name: Name of the TPU Node to be created.
      accelerator_type: Slice type of TPU accelerator like 'v2-8', 'v2-32'.
      tf_version: Tensorflow Version like '1.1', '1.5'.
      zone: Zone to create the TPU Node in.
      preemptible: Boolean argument, to create a Preemptible node.
      network: The network to create the node in.

    Returns:
      A reference to the long-running create operation, to be polled on.
    """
    project = properties.VALUES.core.project.Get(required=True)
    parent_ref = resources.REGISTRY.Parse(
        zone,
        params={'projectsId': project},
        collection='tpu.projects.locations')
    request = self.messages.TpuProjectsLocationsNodesCreateRequest(
        parent=parent_ref.RelativeName(),
        nodeId=name,
        node=self._CreateDefaultNode(
            accelerator_type, tf_version, preemptible, network))
    operation = self.client.projects_locations_nodes.Create(request)
    return self._GetTpuOperationRef(operation)
  def WaitForOperation(self, operation_ref, message):
    """Polls the operation until done and returns the resulting node."""
    operation_poller = waiter.CloudOperationPoller(
        self.client.projects_locations_nodes,
        self.client.projects_locations_operations)
    return waiter.WaitFor(operation_poller, operation_ref, message)
  def WaitForOperationNoResources(self, operation_ref, message):
    """Polls the operation until done; no resource is fetched afterwards."""
    operation_poller = waiter.CloudOperationPollerNoResources(
        self.client.projects_locations_operations)
    return waiter.WaitFor(operation_poller, operation_ref, message)
  def Delete(self, name, zone):
    """Deletes the TPU node with the given name."""
    project = properties.VALUES.core.project.Get(required=True)
    fully_qualified_node_name_ref = resources.REGISTRY.Parse(
        name,
        params={
            'locationsId': zone,
            'projectsId': project
        },
        collection='tpu.projects.locations.nodes',
    )
    request = self.messages.TpuProjectsLocationsNodesDeleteRequest(
        name=fully_qualified_node_name_ref.RelativeName())
    operation = self.client.projects_locations_nodes.Delete(request)
    return self._GetTpuOperationRef(operation)
  def List(self, zone):
    """Retrieves all TPU Nodes in the given zone (lazily, via paging)."""
    project = properties.VALUES.core.project.Get(required=True)
    parent_ref = resources.REGISTRY.Parse(
        zone,
        params={'projectsId': project},
        collection='tpu.projects.locations')
    request = self.messages.TpuProjectsLocationsNodesListRequest(
        parent=parent_ref.RelativeName())
    return list_pager.YieldFromList(
        service=self.client.projects_locations_nodes,
        request=request,
        method='List',
        batch_size_attribute='pageSize',
        field='nodes'
    )
  def Get(self, name, zone):
    """Retrieves the TPU node in the given zone."""
    project = properties.VALUES.core.project.Get(required=True)
    fully_qualified_node_name_ref = resources.REGISTRY.Parse(
        name,
        params={
            'locationsId': zone,
            'projectsId': project
        },
        collection='tpu.projects.locations.nodes',
    )
    request = self.messages.TpuProjectsLocationsNodesGetRequest(
        name=fully_qualified_node_name_ref.RelativeName())
    return self.client.projects_locations_nodes.Get(request)
  def LatestStableTensorflowVersion(self, zone):
    """Parses available Tensorflow versions to find the most stable version.

    Raises:
      HttpNotFoundError: if the API lists no stable (non-nightly,
        modifier-free) TensorFlow release for this zone.
    """
    project = properties.VALUES.core.project.Get(required=True)
    parent_ref = resources.REGISTRY.Parse(
        zone,
        params={'projectsId': project},
        collection='tpu.projects.locations')
    request = self.messages.TpuProjectsLocationsTensorflowVersionsListRequest(
        parent=parent_ref.RelativeName()
    )
    tf_versions = list_pager.YieldFromList(
        service=self.client.projects_locations_tensorflowVersions,
        request=request,
        batch_size_attribute='pageSize',
        field='tensorflowVersions')
    parsed_tf_versions = []
    for tf_version in tf_versions:
      parsed_tf_versions.append(
          TensorflowVersionParser.ParseVersion(tf_version.version))
    # Result.__lt__ orders newest stable releases first and places nightly
    # and unparseable versions after all stable releases.
    sorted_tf_versions = sorted(parsed_tf_versions)
    for version in sorted_tf_versions:
      # Reaching a nightly means no stable release remains in the list.
      if version.is_nightly:
        raise HttpNotFoundError('No stable release found', None, None)
      # The first version with no modifier (e.g. '-rc0') is the latest
      # stable release.
      if not version.modifier:
        return version.VersionString()
    raise HttpNotFoundError('No stable release found', None, None)
  def IsRunning(self, node):
    # A node is usable once READY, or while still CREATING if it has
    # already been assigned an IP address.
    return node.state == self.messages.Node.StateValueValuesEnum.READY or (
        node.state == self.messages.Node.StateValueValuesEnum.CREATING and
        node.ipAddress)
  def NodeName(self, node):
    """Extracts the short node id from the fully-qualified resource name."""
    pattern = 'projects/(.*)/locations/(.*)/nodes/(.*)'
    match = re.search(pattern, node.name, re.IGNORECASE)
    if match:
      return match.group(3)
    return ''
class ComputePollerNoResources(poller.Poller):
  """Compute operation poller whose final result is always None.

  Useful for operations (stop/start/delete) after which no resource should
  be fetched.
  """
  def __init__(self, resource_service, target_ref=None):
    # Delegate construction entirely to poller.Poller; only GetResult
    # behavior differs.
    super(ComputePollerNoResources, self).__init__(
        resource_service=resource_service, target_ref=target_ref)
  def GetResult(self, operation):
    """Overrides poller.Poller.GetResult to skip the final resource GET."""
    return None
class TensorflowVersionParser(object):
  """Helper to parse tensorflow versions."""
  class ParseError(Exception):
    """Error raised when input is unable to be parsed as a TF version."""
  class Result(object):
    """Helper to capture result of parsing the TF version."""
    def __init__(self,
                 major=0,
                 minor=0,
                 patch=0,
                 is_nightly=False,
                 modifier=''):
      self.major = major
      self.minor = minor
      self.patch = patch
      self.is_nightly = is_nightly
      self.modifier = modifier
    def IsUnknown(self):
      """True when the input matched neither a release nor a nightly."""
      return self.major == 0 and self.minor == 0 and not self.is_nightly
    def VersionString(self):
      """Renders the parsed version back to a string.

      NOTE(review): the patch component is not included for stable
      releases, so '2.4.1' renders as '2.4' -- confirm this is
      intentional before relying on round-tripping.
      """
      if self.is_nightly:
        return 'nightly{}'.format(self.modifier)
      if self.major == 0 and self.minor == 0:
        return self.modifier
      return '{}.{}{}'.format(self.major, self.minor, self.modifier)
    def __hash__(self):
      return hash(self.major) + hash(self.minor) + hash(self.patch) + hash(
          self.is_nightly) + hash(self.modifier)
    def __eq__(self, other):
      return (self.major == other.major and
              self.minor == other.minor and
              self.patch == other.patch and
              self.is_nightly == other.is_nightly and
              self.modifier == other.modifier)
    def __lt__(self, other):
      """Orders results so that sorted() yields newest stable first.

      Note the inverted comparisons: a numerically larger version
      compares as "less than" so it sorts earlier. Nightlies sort after
      all stable releases; unknown versions sort last.
      """
      # Both non-nightlies, non-unknowns
      if not self.is_nightly and not other.is_nightly and not self.IsUnknown(
      ) and not other.IsUnknown():
        if self.major != other.major:
          return self.major > other.major
        if self.minor != other.minor:
          return self.minor > other.minor
        if self.patch != other.patch:
          return self.patch > other.patch
        # Same numeric version: a base release sorts before one carrying a
        # modifier such as '-rc0'.
        if not self.modifier:
          return True
        if not other.modifier:
          return False
      # Both nightlies
      if self.is_nightly and other.is_nightly:
        if not self.modifier:
          return True
        if not other.modifier:
          return False
      # Both unknown versions
      if self.IsUnknown() and other.IsUnknown():
        return self.modifier < other.modifier
      # If one is an unknown version, sort after
      if self.IsUnknown():
        return False
      if other.IsUnknown():
        return True
      # Mixed stable/nightly: the stable release sorts first.
      if self.is_nightly:
        return False
      return True
  # 'X.Y<rest>' release versions, 'nightly<rest>' nightlies, and the
  # optional '.Z' patch suffix within <rest>.
  _VERSION_REGEX = re.compile('^(\\d+)\\.(\\d+)(.*)$')
  _NIGHTLY_REGEX = re.compile('^nightly(.*)$')
  _PATCH_NUMBER_REGEX = re.compile('^\\.(\\d+)$')
  @staticmethod
  def ParseVersion(tf_version):
    """Helper to parse the tensorflow version into its subcomponents.

    Args:
      tf_version: version string, e.g. '1.6', '2.4.1', 'nightly-2.x'.

    Returns:
      A TensorflowVersionParser.Result. Strings matching neither pattern
      are returned whole in the modifier field (an "unknown" version).

    Raises:
      TensorflowVersionParser.ParseError: if tf_version is empty.
    """
    if not tf_version:
      raise TensorflowVersionParser.ParseError('Bad argument: '
                                               'tf_version is empty')
    version_match = TensorflowVersionParser._VERSION_REGEX.match(tf_version)
    nightly_match = TensorflowVersionParser._NIGHTLY_REGEX.match(tf_version)
    if version_match is None and nightly_match is None:
      return TensorflowVersionParser.Result(modifier=tf_version)
    # The two anchored patterns cannot both match; treat it as corruption.
    if version_match is not None and nightly_match is not None:
      raise TensorflowVersionParser.ParseError(
          'TF version error: bad version: {}'.format(tf_version))
    if version_match:
      major = int(version_match.group(1))
      minor = int(version_match.group(2))
      result = TensorflowVersionParser.Result(major=major, minor=minor)
      if version_match.group(3):
        patch_match = TensorflowVersionParser._PATCH_NUMBER_REGEX.match(
            version_match.group(3))
        if patch_match:
          matched_patch = int(patch_match.group(1))
          # A patch of 0 is left as the default 0 rather than stored.
          if matched_patch:
            result.patch = matched_patch
        else:
          # Non-numeric suffix (e.g. '-rc0') becomes the modifier.
          result.modifier = version_match.group(3)
      return result
    if nightly_match:
      result = TensorflowVersionParser.Result(is_nightly=True)
      if nightly_match.group(1):
        result.modifier = nightly_match.group(1)
      return result
class Instance(object):
  """Helper to create the GCE VM required to work with the TPU Node."""
  def __init__(self, release_track):
    holder = base_classes.ComputeApiHolder(release_track)
    self.client = holder.client.apitools_client
    self.messages = holder.client.messages
  def _ImageFamilyFromTensorflowVersion(self, tf_version, use_dl_image):
    """Generates the image family from the tensorflow version."""
    if tf_version == 'nightly':
      return 'tf-nightly'
    parsed = TensorflowVersionParser.ParseVersion(tf_version)
    # Only plain numeric versions map to image families; '-rc0'-style
    # modifiers require the user to supply --gce-image explicitly.
    if parsed.modifier:
      raise TensorflowVersionParser.ParseError('Invalid tensorflow version:{} '
                                               '(non-empty modifier); please '
                                               'set the --gce-image '
                                               'flag'.format(tf_version))
    if use_dl_image:
      # Deep Learning VM family naming, e.g. 'tf2-2-4-cpu' / 'tf-1-15-cpu'.
      if parsed.major == 2:
        return 'tf2-{}-{}-cpu'.format(parsed.major, parsed.minor)
      else:
        return 'tf-{}-{}-cpu'.format(parsed.major, parsed.minor)
    # From TF 2.4, image family format uses patch format by default,
    # e.g.: `tf-2-4-0` for TF version 2.4
    # NOTE(review): 'major >= 2 and minor >= 4' would exclude a
    # hypothetical TF 3.0 -- presumably 'major > 2 or (major == 2 and
    # minor >= 4)' was intended; confirm before changing.
    if parsed.patch or (parsed.major >= 2 and parsed.minor >= 4):
      return 'tf-{}-{}-{}'.format(parsed.major, parsed.minor, parsed.patch)
    return 'tf-{}-{}'.format(parsed.major, parsed.minor)
  def ResolveImageFromTensorflowVersion(self, tf_version, use_dl_image):
    """Queries GCE to find the right image for the given TF version."""
    # DLVM images are published in a different public project than the
    # TPU-specific machine images.
    project = 'ml-images'
    if use_dl_image:
      project = 'deeplearning-platform-release'
    image_family = self._ImageFamilyFromTensorflowVersion(
        tf_version, use_dl_image)
    request = self.messages.ComputeImagesGetFromFamilyRequest(
        family=image_family, project=project)
    image = self.client.images.GetFromFamily(request)
    return image and image.selfLink
  def BuildInstanceSpec(self,
                        name,
                        zone,
                        machine_type,
                        disk_size,
                        preemptible,
                        network,
                        use_with_notebook,
                        source_image=None):
    """Builds an instance spec to be used for Instance creation."""
    disk = self.messages.AttachedDisk(
        boot=True,
        autoDelete=True,
        initializeParams=self.messages.AttachedDiskInitializeParams(
            sourceImage=source_image,
            diskSizeGb=disk_size
        ))
    project_number = p_util.GetProjectNumber(
        properties.VALUES.core.project.Get(required=True))
    network_interface = self.messages.NetworkInterface(
        network='projects/{}/global/networks/{}'.format(
            project_number, network),
        accessConfigs=[self.messages.AccessConfig(
            name='External NAT',
            type=self.messages.AccessConfig.TypeValueValuesEnum.ONE_TO_ONE_NAT)]
    )
    # The 'ctpu' metadata entry and label mark VMs created by execution
    # groups; _VMCreatedByExecGroup filters on the label.
    metadata = [self.messages.Metadata.ItemsValueListEntry(
        key='ctpu',
        value=name)]
    if use_with_notebook:
      metadata.append(
          self.messages.Metadata.ItemsValueListEntry(
              key='proxy-mode', value='project_editors'))
    service_account = self.messages.ServiceAccount(
        email='default',
        scopes=[
            'https://www.googleapis.com/auth/devstorage.read_write',
            'https://www.googleapis.com/auth/logging.write',
            'https://www.googleapis.com/auth/monitoring.write',
            'https://www.googleapis.com/auth/cloud-platform'
        ])
    labels = self.messages.Instance.LabelsValue(additionalProperties=[
        self.messages.Instance.LabelsValue.AdditionalProperty(
            key='ctpu', value=name)
    ])
    return self.messages.Instance(
        name=name,
        metadata=self.messages.Metadata(items=metadata),
        machineType='zones/{}/machineTypes/{}'.format(zone, machine_type),
        disks=[disk],
        scheduling=self.messages.Scheduling(preemptible=preemptible),
        networkInterfaces=[network_interface],
        labels=labels,
        serviceAccounts=[service_account])
  def _GetComputeZoneOperationRef(self, operation):
    """Get a resource reference to a long running operation."""
    return resources.REGISTRY.Parse(
        operation.selfLink, collection='compute.zoneOperations')
  def Create(self, name, zone, machine_type, disk_size, preemptible, gce_image,
             network, use_with_notebook):
    """Issue request to create an Instance."""
    request = self.messages.ComputeInstancesInsertRequest(
        project=properties.VALUES.core.project.Get(required=True),
        zone=zone,
        instance=self.BuildInstanceSpec(
            name, zone, machine_type, disk_size, preemptible, network,
            use_with_notebook, gce_image))
    operation = self.client.instances.Insert(request)
    return self._GetComputeZoneOperationRef(operation)
  def Stop(self, name, zone):
    """Issue request to stop the Instance."""
    project = properties.VALUES.core.project.Get(required=True)
    request = self.messages.ComputeInstancesStopRequest(
        instance=name,
        project=project,
        zone=zone
    )
    operation = self.client.instances.Stop(request)
    return self._GetComputeZoneOperationRef(operation)
  def Start(self, name, zone):
    """Issue request to start the Instance."""
    project = properties.VALUES.core.project.Get(required=True)
    request = self.messages.ComputeInstancesStartRequest(
        instance=name,
        project=project,
        zone=zone
    )
    operation = self.client.instances.Start(request)
    return self._GetComputeZoneOperationRef(operation)
  def WaitForOperation(self, operation_ref, message):
    """Wait for Instance operation to complete."""
    operation_poller = poller.Poller(self.client.instances)
    return waiter.WaitFor(operation_poller, operation_ref, message)
  def WaitForOperationNoResources(self, operation_ref, message):
    """Wait for the operation; no resource is fetched on completion."""
    operation_poller = ComputePollerNoResources(self.client.instances)
    return waiter.WaitFor(operation_poller, operation_ref, message)
  def List(self, zone):
    """Retrieves all Instances created by Execution Group."""
    project = properties.VALUES.core.project.Get(required=True)
    request = self.messages.ComputeInstancesListRequest(
        zone=zone, project=project)
    instances = list_pager.YieldFromList(
        service=self.client.instances,
        request=request,
        method='List',
        field='items')
    result_set = []
    # Keep only VMs carrying the 'ctpu' label set by BuildInstanceSpec.
    for instance in instances:
      if self._VMCreatedByExecGroup(instance):
        result_set.append(instance)
    return result_set
  def Get(self, instance_name, zone):
    """Retrieves the Instance data.

    Raises:
      HttpNotFoundError: if the VM exists but was not created by an
        execution group (no 'ctpu' label).
    """
    project = properties.VALUES.core.project.Get(required=True)
    request = self.messages.ComputeInstancesGetRequest(
        zone=zone, project=project, instance=instance_name)
    instance = self.client.instances.Get(request)
    if self._VMCreatedByExecGroup(instance):
      return instance
    raise HttpNotFoundError(
        'Instance:{} not found'.format(instance_name), None, None)
  def _VMCreatedByExecGroup(self, instance):
    """True when the instance carries the 'ctpu' label."""
    if instance and instance.labels:
      for label in instance.labels.additionalProperties:
        if label.key == 'ctpu':
          return True
    return False
  def IsRunning(self, instance):
    """True when the instance status is RUNNING."""
    return instance.status == self.messages.Instance.StatusValueValuesEnum.RUNNING
  def Delete(self, name, zone):
    """Deletes the specified instance in the given zone and project."""
    request = self.messages.ComputeInstancesDeleteRequest(
        project=properties.VALUES.core.project.Get(required=True),
        zone=zone,
        instance=name
    )
    operation = self.client.instances.Delete(request)
    return self._GetComputeZoneOperationRef(operation)
class SSH(object):
  """Helper class to SSH to the VM associated with the TPU node."""
  def __init__(self, release_track):
    holder = base_classes.ComputeApiHolder(release_track)
    self.release_track = release_track
    self.client = holder.client
    self.resources = holder.resources
  def _DefaultArgsForSSH(self, args):
    """Populates SSH-utils-required attributes not exposed as flags."""
    # These arguments are not exposed to the user but are required in
    # order to use the SSH Utils.
    args.plain = None
    args.strict_host_key_checking = 'no'
    args.force_key_file_overwrite = None
    args.ssh_key_file = None
    return args
  def _GetHostKeyFromInstance(self, zone, ssh_helper, instance):
    """Wrapper around SSH Utils to get the host keys for SSH."""
    instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
        [instance.name], compute_scope.ScopeEnum.ZONE, zone,
        self.resources,
        scope_lister=instance_flags.GetInstanceZoneScopeLister(self.client))[0]
    project = ssh_helper.GetProject(self.client, instance_ref.project)
    host_keys = ssh_helper.GetHostKeysFromGuestAttributes(
        self.client, instance_ref, instance, project)
    if host_keys is not None and not host_keys:
      # Only display this message if there was an attempt to retrieve
      # host keys but it was unsuccessful(yielded empty dict). If Guest
      # Attributes is disabled, there is no attempt to retrieve host keys.
      log.status.Print('Unable to retrieve host keys from instance metadata. '
                       'Continuing.')
    return host_keys
  def _GetSSHOptions(self, name, ssh_helper, instance, host_keys):
    """Builds SSH options and exports TPU_NAME to the remote session."""
    options = ssh_helper.GetConfig(ssh_utils.HostKeyAlias(instance),
                                   strict_host_key_checking='no',
                                   host_keys_to_add=host_keys)
    # SendEnv forwards the local TPU_NAME variable so tooling on the VM can
    # locate the TPU node by name.
    os.environ['TPU_NAME'] = name
    options['SendEnv'] = 'TPU_NAME'
    return options
  def _WaitForSSHKeysToPropagate(
      self, ssh_helper, remote, identity_file, user, instance, options,
      putty_force_connect=False):
    """Waits for SSH keys to propagate in order to SSH to the instance.

    Raises:
      ssh_utils.NetworkError: if the instance is not reachable over SSH
        within the polling window.
    """
    # Ensure the public key is in metadata, allowing up to 300s.
    ssh_helper.EnsureSSHKeyExists(
        self.client, user, instance,
        ssh_helper.GetProject(
            self.client, properties.VALUES.core.project.Get(required=True)),
        times.Now() + datetime.timedelta(seconds=300))
    ssh_poller = ssh.SSHPoller(
        remote=remote,
        identity_file=identity_file, options=options, max_wait_ms=300*1000)
    try:
      ssh_poller.Poll(
          ssh_helper.env,
          putty_force_connect=putty_force_connect)
    except retry.WaitException:
      raise ssh_utils.NetworkError()
  def SSHToInstance(self, args, instance):
    """Helper to manage authentication followed by SSH to the instance."""
    args = self._DefaultArgsForSSH(args)
    external_nat = ssh_utils.GetExternalIPAddress(instance)
    log.status.Print(
        'Trying to SSH to VM with NAT IP:{}'.format(external_nat))
    args.ssh_key_file = ssh.Keys.DEFAULT_KEY_FILE
    ssh_helper = ssh_utils.BaseSSHCLIHelper()
    ssh_helper.Run(args)
    identity_file = ssh_helper.keys.key_file
    user, _ = ssh_utils.GetUserAndInstance(args.name)
    host_keys = self._GetHostKeyFromInstance(args.zone, ssh_helper, instance)
    options = self._GetSSHOptions(args.name, ssh_helper,
                                  instance, host_keys)
    public_key = ssh_helper.keys.GetPublicKey().ToEntry(include_comment=True)
    # OS Login state decides the effective remote username and whether key
    # propagation through metadata is needed at all.
    oslogin_state = ssh.GetOsloginState(
        instance,
        ssh_helper.GetProject(
            self.client, properties.VALUES.core.project.Get(required=True)),
        user,
        public_key,
        None,
        self.release_track,
        username_requested=False,
        messages=self.client.messages)
    user = oslogin_state.user
    remote = ssh.Remote(external_nat, user)
    # TODO(b/35355795): Don't force connect in general.
    # At a minimum, avoid injecting 'y' if PuTTY will prompt for a 2FA
    # authentication method (since we know that won't work), or if the user has
    # disabled the property.
    putty_force_connect = (
        not oslogin_state.oslogin_2fa_enabled and
        properties.VALUES.ssh.putty_force_connect.GetBool())
    if not oslogin_state.oslogin_enabled:
      self._WaitForSSHKeysToPropagate(ssh_helper, remote, identity_file, user,
                                      instance, options, putty_force_connect)
    extra_flags = []
    # Ctpu seems to be forwarding some other ports on what
    # seems like the TPU node. Need to understand better before enabling.
    if args.forward_ports:
      extra_flags.extend(
          ['-A', '-L', '6006:localhost:6006', '-L', '8888:localhost:8888'])
    ssh_cmd_args = {
        'remote': remote,
        'identity_file': identity_file,
        'options': options,
        'extra_flags': extra_flags
    }
    cmd = ssh.SSHCommand(**ssh_cmd_args)
    max_attempts = 10
    sleep_interval = 30
    # Since the instance was just created, it can take a while for the instance
    # to be ready to accept ssh connections, therefore retry up to 5m. Doesn't
    # need to be backed off, regular interval retry is sufficient since we
    # aren't looking to throttle.
    for i in range(max_attempts):
      try:
        log.status.Print('SSH Attempt #{}...'.format(i))
        # Errors from SSH itself result in an ssh.CommandError being raised
        return_code = cmd.Run(
            ssh_helper.env,
            putty_force_connect=putty_force_connect)
        if return_code:
          # This is the return code of the remote command.
          # Problems with SSH itself will result in ssh.CommandError
          # being raised above.
          sys.exit(return_code)
      except ssh.CommandError as e:
        if i == max_attempts - 1:
          raise e
        log.status.Print(
            'Retrying: SSH command error: {}'.format(six.text_type(e)))
        time.sleep(sleep_interval)
        continue
      # Successful connection (and remote exit code 0): stop retrying.
      break
class ResourceManager(object):
  """Helper to interact with Cloud Resource Manager and related ACLs."""
  logging_role = 'roles/logging.logWriter'
  storage_role = 'roles/storage.admin' # Note storage.objectAdmin does not work
                                       # in certain cases, and thus we need
                                       # roles/storage.admin.
  tpu_service_agent = 'roles/tpu.serviceAgent'
  def __init__(self):
    self._api_version = 'v1'
    self.client = apis.GetClientInstance(
        'cloudresourcemanager', self._api_version)
    self.messages = apis.GetMessagesModule(
        'cloudresourcemanager', self._api_version)
  def AddTpuUserAgent(self, tpu_user_agent):
    """AddTPUUserAgent adds the TPU user agent to enable Cloud Storage access and send logging."""
    project = properties.VALUES.core.project.Get(required=True)
    # Read-modify-write cycle on the project-level IAM policy.
    get_iam_policy_request = self.messages.CloudresourcemanagerProjectsGetIamPolicyRequest(
        resource=project)
    policy = self.client.projects.GetIamPolicy(get_iam_policy_request)
    policy = self._AddAgentToPolicy(policy, tpu_user_agent)
    # _AddAgentToPolicy returns None when the agent is already present, in
    # which case there is nothing to write back.
    if policy is None:
      log.status.Print('TPU Service account:{} has already been enabled'
                       .format(tpu_user_agent))
    else:
      set_iam_policy_request = self.messages.CloudresourcemanagerProjectsSetIamPolicyRequest(
          resource=project,
          setIamPolicyRequest=self.messages.SetIamPolicyRequest(
              policy=policy
          ))
      self.client.projects.SetIamPolicy(set_iam_policy_request)
      log.status.Print(
          'Added Storage and Logging permissions to TPU Service Account:{}'
          .format(tpu_user_agent))
  def _AddAgentToPolicy(self, policy, tpu_user_agent):
    """Adds the tpuUserAgent to the policy and return it.

    Returns None (making no modifications) when the agent already appears
    as a member of any binding other than the tpuServiceAgent role.
    """
    logging_binding = None
    storage_binding = None
    tpu_member_str = 'serviceAccount:{}'.format(tpu_user_agent)
    for binding in policy.bindings:
      if binding.role == self.logging_role:
        logging_binding = binding
      if binding.role == self.storage_role:
        storage_binding = binding
      # Skip checking bindings if this is the tpuServiceAgent role.
      if binding.role != self.tpu_service_agent:
        # Check if the tpuMemberStr is already in a binding.
        for member in binding.members:
          if member == tpu_member_str:
            # The TPU service account has already been enabled. Make no
            # modifications.
            return None
    # Create the logging/storage bindings if the policy lacks them.
    if logging_binding is None:
      logging_binding = self.messages.Binding(role=self.logging_role)
      policy.bindings.append(logging_binding)
    if storage_binding is None:
      storage_binding = self.messages.Binding(role=self.storage_role)
      policy.bindings.append(storage_binding)
    logging_binding.members.append(tpu_member_str)
    storage_binding.members.append(tpu_member_str)
    return policy

View File

@@ -0,0 +1,272 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flag Utilities for cloud tpu commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import arg_parsers
from googlecloudsdk.calliope import base
def GetTPUNameArg():
  """Returns the positional argument for the TPU name."""
  return base.Argument('tpu_id', help='Name of the TPU.')
def GetDescriptionFlag():
  """Returns the --description flag for TPU commands."""
  return base.Argument(
      '--description', help='Specifies a text description of the TPU.')
def GetAcceleratorTypeFlag():
  """Builds the --accelerator-type flag (lower-cased, default 'v2-8')."""
  accelerator_help = """\
      TPU accelerator type for the TPU.
      If not specified, this defaults to `v2-8`.
      """
  # str.lower is equivalent to the previous lambda for string input and
  # normalizes values like 'V2-8'.
  return base.Argument(
      '--accelerator-type',
      default='v2-8',
      type=str.lower,
      required=False,
      help=accelerator_help)
def GetVersionFlag():
  """Builds the required --version flag for selecting the TF version."""
  version_help = """\
      TensorFlow version for the TPU, such as `1.6`. For a list of available
      TensorFlow versions please see https://www.tensorflow.org/versions/.
      """
  return base.Argument('--version', required=True, help=version_help)
def GetRangeFlag():
  """Builds the optional --range flag for the TPU's CIDR range."""
  range_help = """\
      CIDR Range for the TPU.
      The IP range that the TPU will select an IP address from.
      Must be in CIDR notation and a `/29` range, for example `192.168.0.0/29`.
      Errors will occur if the CIDR range has already been used for a
      currently existing TPU, the CIDR range conflicts with any networks
      in the user's provided network, or the provided network is peered with
      another network that is using that CIDR range.
      """
  return base.Argument('--range', required=False, help=range_help)
def AddPreemptibleFlag(parser):
  """Adds the --preemptible boolean flag (default False) to parser."""
  preemptible_help = """\
      Create a preemptible Cloud TPU, instead of a normal (non-preemptible) Cloud TPU. A
      preemptible Cloud TPU costs less per hour, but the Cloud TPU service can stop/terminate
      the node at any time.
      """
  return parser.add_argument(
      '--preemptible',
      action='store_true',
      default=False,
      required=False,
      help=preemptible_help)
def AddTpuNameArg(parser):
  """Adds the positional execution group name argument to parser."""
  name_help = """\
      The execution group name to delete. """
  return parser.add_argument('execution_group_name', help=name_help)
def AddTpuNameOverrideArg(parser):
  """Adds --name for overriding the default VM/TPU name."""
  override_help = """\
      Override the name to use for VMs and TPUs (defaults to your username). """
  return parser.add_argument('--name', help=override_help)
def AddPreemptibleVmFlag(parser):
  """Adds the --preemptible-vm boolean flag (default False) to parser."""
  preemptible_vm_help = """\
      Create a preemptible Compute Engine VM, instead of a normal (non-preemptible) VM.
      A preemptible VM costs less per hour, but the Compute Engine service can terminate the
      instance at any time.
      """
  return parser.add_argument(
      '--preemptible-vm',
      action='store_true',
      default=False,
      required=False,
      help=preemptible_vm_help)
def AddTfVersionFlag(parser, help_text_override=None):
  """Adds the --tf-version flag to parser.

  Args:
    parser: the argparse parser the flag is added to.
    help_text_override: optional string replacing the default help text.

  Returns:
    The argparse action created for --tf-version.
  """
  help_text = """\
  Set the version of TensorFlow to use when creating the Compute Engine VM and the Cloud TPU.
  (It defaults to auto-selecting the latest stable release.)
  """
  return parser.add_argument(
      '--tf-version',
      required=False,
      help=help_text_override or help_text
  )
def AddTfVersionFlagForResume(parser):
  """Adds --tf-version with help text tailored to the resume command.

  Fix: the created argument is now returned, consistent with every other
  Add*Flag helper in this module (previously the return value was
  dropped). Also fixes the stray space before the period in the help
  text.

  Returns:
    The argparse action created for --tf-version.
  """
  help_text_override = """\
  Set the version of TensorFlow to the version originally set when creating the suspended Cloud TPU and Compute Engine VM.
  (It defaults to auto-selecting the latest stable release.)
  """
  return AddTfVersionFlag(parser, help_text_override)
def AddVmOnlyFlag(parser):
  """Adds --vm-only, which skips TPU allocation when set."""
  vm_only_help = """\
      Do not allocate a TPU, only allocate a VM (useful if you're not ready to run on a TPU yet).
      """
  return parser.add_argument(
      '--vm-only',
      default=False,
      required=False,
      action='store_true',
      help=vm_only_help)
def AddTpuOnlyFlag(parser, help_text_override=None):
  """Adds --tpu-only; the optional override customizes the help text."""
  tpu_only_help = """\
  Do not allocate a VM, only allocate a TPU. To be used after the command has been run with a --vm-only flag
  and the user is ready to run on a TPU. Ensure that the name matches the name passed in when creating with the --vm-only flag.
  """
  return parser.add_argument(
      '--tpu-only',
      default=False,
      required=False,
      action='store_true',
      help=help_text_override or tpu_only_help)
def AddTpuOnlyFlagForDelete(parser):
  """Adds --tpu-only with help text tailored to the delete command."""
  return AddTpuOnlyFlag(parser, """\
  Do not delete VM, only delete the TPU.
  """)
def AddDeepLearningImagesFlag(parser):
  """Adds --use-dl-images for selecting Deep Learning VM images."""
  dl_images_help = """\
      Use Deep Learning VM Images (see docs - https://cloud.google.com/deep-learning-vm/) instead
      of TPU-specific machine images. Defaults to TPU-specific images. This
      value is set to true automatically if the --use-with-notebook flag is
      set to true.
      """
  return parser.add_argument(
      '--use-dl-images',
      default=False,
      required=False,
      action='store_true',
      help=dl_images_help)
def AddDryRunFlag(parser):
  """Adds the --dry-run boolean flag (default False) to parser."""
  dry_run_help = """\
      Do not make changes; print only what would have happened.
      """
  return parser.add_argument(
      '--dry-run',
      action='store_true',
      default=False,
      required=False,
      help=dry_run_help)
def AddPortForwardingFlag(parser):
  """Adds the --forward-ports flag to the parser.

  NOTE(review): action='store_false' means the destination defaults to True
  and passing --forward-ports sets it to False — the inverse of what the help
  text suggests. Behavior is preserved here; confirm the intended wiring.
  """
  forwarding_help = """\
Automatically forward useful ports from the Compute Engine VM to your local
machine. The ports forwarded are: 6006 (tensorboard), 8888 (jupyter notebooks),
8470 (TPU port), 8466 (TPU profiler port).
"""
  return parser.add_argument(
      '--forward-ports',
      required=False,
      action='store_false',
      help=forwarding_help)
def AddGceImageFlag(parser):
  """Adds the --gce-image flag to the parser."""
  gce_image_help = """\
Override the automatically chosen Compute Engine Image. Use this flag when you're using
your own custom images instead of the provided ones with TensorFlow pre-installed.
"""
  return parser.add_argument('--gce-image', help=gce_image_help)
def AddDiskSizeFlag(parser):
  """Adds the --disk-size flag (root volume size) to the parser."""
  size_type = arg_parsers.BinarySize(
      lower_bound='20GB',
      upper_bound='2000GB',
      suggested_binary_size_scales=['GB'])
  disk_size_help = """\
Configures the root volume size of your Compute Engine VM (in GB). The
minimum size is 20GB and the maximum is 2000GB. Specified value must be an
integer multiple of Gigabytes.
"""
  return parser.add_argument(
      '--disk-size',
      type=size_type,
      default='250GB',
      help=disk_size_help)
def AddMachineTypeArgs(parser):
  """Adds the --machine-type flag to the parser."""
  machine_type_help = """\
Specifies the machine type used for the Compute Engine VM. To get a
list of available machine types, run 'gcloud compute
machine-types list'. If unspecified, the default type is n1-standard-1.
"""
  return parser.add_argument(
      '--machine-type',
      default='n1-standard-1',
      help=machine_type_help)
def AddUseWithNotebook(parser):
  """Adds the --use-with-notebook boolean flag to the parser."""
  notebook_help = """\
Allow Compute Engine VM to be recognized by Cloud AI Notebooks. This
automatically sets the content of the flag --use-dl-images flag to be
true.
"""
  return parser.add_argument(
      '--use-with-notebook',
      default=False,
      required=False,
      action='store_true',
      help=notebook_help,
  )
def AddNetworkArgs(parser, help_text_override=None):
  """Adds the --network flag, optionally overriding its help text."""
  default_help = """\
Specifies the network the Cloud TPU and associated VM should be created in.
If unspecified, the network "default" is picked.
"""
  return parser.add_argument(
      '--network',
      help=help_text_override or default_help,
      default='default')
def AddNetworkArgsForResume(parser):
  """Adds the --network flag with resume-specific help text."""
  override = """\
Set to the network that was originally used creating the suspended Cloud TPU
and Compute Engine VM. (It defaults to using the 'default' network.)
"""
  return AddNetworkArgs(parser, override)

View File

@@ -0,0 +1,199 @@
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
accelerator_type:
api_field: node.acceleratorType
arg_name: accelerator-type
required: false
help_text: |
TPU accelerator type for the TPU.
If not specified, this defaults to `v2-8`.
For a list of available accelerator types run:
`{parent_command} accelerator-types list`
type: googlecloudsdk.command_lib.util.hooks.types:LowerCaseType
default: 'v2-8'
image:
arg_name: image
required: false
help_text: |
Specifies the full URI of the machine image to use for creating the TPU VM's boot disk.
If specified, this will override the boot image that would normally be used by the
specified `--runtime-version`.
description:
api_field: node.description
arg_name: description
required: false
help_text: |
Specifies a text description of the TPU.
network:
api_field: node.network
arg_name: network
required: false
help_text: |
Specifies the network that this TPU will be a part of.
default: 'default'
version:
api_field: node.tensorflowVersion
arg_name: version
required: true
help_text: |
TensorFlow version for the TPU, such as `1.14`. For a list of available
TensorFlow versions please see https://www.tensorflow.org/versions/.
range: &range
api_field: node.cidrBlock
arg_name: range
help_text: |
CIDR Range for the TPU.
The IP range that the TPU will select an IP address from.
Must be in CIDR notation and a `/29` range, for example
`192.168.0.0/29`. Errors will occur if the CIDR range has already been
used for a currently existing TPU, the CIDR range conflicts with any
networks in the user's provided network, or the provided network is
peered with another network that is using that CIDR range.
preemptible:
api_field: node.schedulingConfig.preemptible
arg_name: preemptible
type: bool
default: false
required: false
help_text: |
If provided, the TPU will be preemptible and time-limited. It may be
preempted to free up resources for standard TPUs, and will only be able
to run for a limited amount of time.
Preemptible TPUs cannot be restarted.
service_networking:
api_field: node.useServiceNetworking
arg_name: use-service-networking
type: bool
default: false
help_text: |-
If provided, the TPU will be configured via the Service Networking (SN) API instead of
using a CIDR range. The Service Networking API should be enabled on the project before
creating the TPU.
For more information on Service Networking see https://cloud.google.com/service-infrastructure/docs/service-networking/getting-started.
reserved:
api_field: node.schedulingConfig.reserved
arg_name: reserved
type: bool
default: false
required: false
help_text: |
When specified, will attempt to create the TPU node under reservations made in the current
project. The reservations can be made separately but used in aggregated form. i.e., the user can
make a reservation of 128 V2 TPUs and later on make another reservation of 128 V2 TPUs then
creates a v2-256 TPU instance. If there exists no reservation or not sufficient amount of
reserved cores under the project, the request will fail due to lack of capacity.
metadata:
arg_name: metadata
metavar: KEY=VALUE
type: "googlecloudsdk.calliope.arg_parsers:ArgDict:"
required: false
help_text: |
List of comma-separated metadata key-value pairs for the Cloud TPU VM node.
Example: `--metadata='key1=value1,key2=value2'`
metadata_from_file:
arg_name: metadata-from-file
metavar: KEY=VALUE
type: "googlecloudsdk.calliope.arg_parsers:ArgDict:"
required: false
help_text: |
Same as `--metadata` except the value for the entry will be read from a local file.
Example: `--metadata-from-file='key1=value1.txt'`
update_metadata:
arg_name: update-metadata
type: "googlecloudsdk.calliope.arg_parsers:ArgDict:"
required: false
metavar: KEY=VALUE
help_text: |
List of comma-separated metadata key-value pairs for the Cloud TPU VM node. If a key exists, its
value is modified. Otherwise, a new key-value pair is created.
clear_metadata:
arg_name: clear-metadata
type: bool
default: false
required: false
help_text: |
Remove all metadata. If `--update-metadata` is also specified then `--clear-metadata` is applied
first.
For example, to remove all metadata:
$ {command} --clear-metadata
To remove all existing metadata and create two new metadata key-value pairs, 'foo=bar' and
'baz=qux':
$ {command} --clear-metadata --update-metadata foo=bar,baz=qux
remove_metadata:
arg_name: remove-metadata
type: "googlecloudsdk.calliope.arg_parsers:ArgList:"
required: false
metavar: KEY
help_text: |
List of comma-separated metadata keys to remove. If a key does not exist it is silently ignored.
If `--update-metadata` is also specified then `--update-metadata` is applied first.
topology:
arg_name: topology
help_text: |
Chip topology for TPU.
type: googlecloudsdk.command_lib.util.hooks.types:LowerCaseType
type:
arg_name: type
help_text: |
Type of TPU.
choices:
- arg_value: v2
enum_value: V2
- arg_value: v3
enum_value: V3
- arg_value: v4
enum_value: V4
- arg_value: v5litepod
enum_value: V5LITE_POD
- arg_value: v5p
enum_value: V5P
- arg_value: v6e
enum_value: V6E
ga-type:
arg_name: type
help_text: |
Type of TPU.
choices:
- arg_value: v2
enum_value: V2
- arg_value: v3
enum_value: V3
- arg_value: v4
enum_value: V4

View File

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,375 @@
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
node_id:
arg_name: node-id
required: true
help_text: |
Unqualified node identifier used to identify the node in the project once provisioned.
To request a resource with multiple nodes, in place of `--node-id`, use `--node-count` to
specify the number of nodes and optionally use `--node-prefix` to specify the prefix for each
node.
node_count:
arg_name: node-count
required: true
type: int
help_text: |
The number of nodes in a multislice provision, also used to generate the qualified name for
nodes in the provision. Value must be greater than 1.
node_prefix:
arg_name: node-prefix
help_text: |
Node prefix used to generate the qualified name of each node the multislice node provision. If
not supplied, the queued resource id will be used as the prefix.
Must also specify `--node-count`.
workload_type:
arg_name: workload-type
help_text: |
Specifies the workload type for the multi-node TPUs.
accelerator_type:
arg_name: accelerator-type
type: googlecloudsdk.command_lib.util.hooks.types:LowerCaseType
help_text: |
Accelerator type for the TPU.
runtime_version:
arg_name: runtime-version
required: true
help_text: |
Runtime version for the TPU, such as `tpu-ubuntu2204-base`.
best_effort:
arg_name: best-effort
type: bool
action: store_true
help_text: |
If provided, the Node requested here may be scheduled at the 'best effort' tier.
spot:
arg_name: spot
type: bool
action: store_true
help_text: |
If provided, the Node requested here will be created as Spot VMs.
guaranteed:
arg_name: guaranteed
type: bool
action: store_true
help_text: |
If provided, the Node requested here will only be scheduled at the 'guaranteed' tier.
provisioning_model:
arg_name: provisioning-model
help_text: |
If provided, the resource will be provisioned with the specified provisioning model.
max_run_duration:
arg_name: max-run-duration
type: googlecloudsdk.core.util.times:ParseDuration
processor: googlecloudsdk.core.util.times:FormatDurationForJson
api_field: queuedResource.runDuration.maxRunDuration
help_text: |
A duration after which the resource will be terminated. Used with flex-start.
See $ gcloud topic datetimes for information on duration formats.
min_duration:
arg_name: min-duration
api_field: queuedResource.guaranteed.minDuration
help_text: |
The minimum period of time the Node is needed. If specified, the requested Node will only
be scheduled if there is sufficient capacity for the given duration.
If this flag is set the guaranteed flag is also set.
reserved:
arg_name: reserved
type: bool
action: store_true
default: null
help_text: |
Specifies the request should be scheduled on reserved capacity.
If `--reservation-host-project`, `--reservation-host-folder`, or
`--reservation-host-organization` are present then this flag has no effect.
valid_after_duration:
arg_name: valid-after-duration
type: googlecloudsdk.core.util.times:ParseDuration
processor: googlecloudsdk.core.util.times:FormatDurationForJson
api_field: queuedResource.queueingPolicy.validAfterDuration
help_text: |
A duration before which the TPU must not be provisioned, relative to the current time.
See $ gcloud topic datetimes for information on duration formats.
valid_after_time:
arg_name: valid-after-time
type: googlecloudsdk.core.util.times:ParseDateTime
processor: googlecloudsdk.core.util.times:FormatDateTime
api_field: queuedResource.queueingPolicy.validAfterTime
help_text: |
An absolute time before which the TPU must not be provisioned.
See $ gcloud topic datetimes for information on duration formats.
valid_until_duration:
arg_name: valid-until-duration
type: googlecloudsdk.core.util.times:ParseDuration
processor: googlecloudsdk.core.util.times:FormatDurationForJson
api_field: queuedResource.queueingPolicy.validUntilDuration
help_text: |
A duration after which the TPU must not be provisioned, relative to the current time.
See $ gcloud topic datetimes for information on duration formats.
valid_until_time:
arg_name: valid-until-time
type: googlecloudsdk.core.util.times:ParseDateTime
processor: googlecloudsdk.core.util.times:FormatDateTime
api_field: queuedResource.queueingPolicy.validUntilTime
help_text: |
An absolute time after which resources must not be created.
See $ gcloud topic datetimes for information on duration formats.
internal_ips:
arg_name: internal-ips
type: bool
action: store_true
default: false # note that user-facing flag is inverse of API enable_external_ips flag
help_text: |
Indicates that the IP addresses for the node should be internal. The default is that external IP
addresses will be associated with the TPU workers.
reservation_host_project:
arg_name: reservation-host-project
help_text: |
The project hosting the reservation that the TPU should use. Only one reservation host entity
may be specified.
reservation_host_folder:
arg_name: reservation-host-folder
help_text: |
The folder hosting the reservation that the TPU should use. Only one reservation host entity
may be specified.
reservation_host_organization:
arg_name: reservation-host-organization
help_text: |
The organization hosting the reservation that the TPU should use. Only one reservation host
entity may be specified.
force:
arg_name: force
type: bool
help_text: |
If set to true, any nodes in this queued resource will also be deleted.
Otherwise, the request will only work if the queued resource has no nodes.
network:
arg_name: network
help_text: |
Network that this TPU will be a part of.
default: 'default'
subnetwork:
arg_name: subnetwork
help_text: |
Subnetwork that this TPU will be a part of.
service_account:
arg_name: service-account
help_text: |
Email address of the service account. If empty, default Google Compute Engine service
account will be used.
service_account_scopes:
arg_name: scopes
type: "googlecloudsdk.calliope.arg_parsers:ArgList:"
help_text: |
List of comma-separated scopes to be made available for the service account.
tags:
arg_name: tags
type: "googlecloudsdk.calliope.arg_parsers:ArgList:"
required: false
help_text: |
Tags to apply to the TPU Node. Tags are used to identify valid sources or
targets for network firewalls. See https://cloud.google.com/vpc/docs/add-remove-network-tags for
more details.
data_disks:
api_field: queuedResource.tpu.nodeSpec.node.dataDisks
arg_name: data-disk
type:
arg_dict:
flatten: false
spec:
- api_field: sourceDisk
arg_name: source
type: str
required: true
- api_field: mode
arg_name: mode
type: str
required: false
choices:
- arg_value: read-write
enum_value: READ_WRITE
- arg_value: read-only
enum_value: READ_ONLY
required: false
help_text: |
Additional data disks for the TPU VM.
This flag must be repeated to provide multiple data disks. For example:
$ {command} --data-disk source=projects/my-project/zones/us-central1-c/disks/my-disk,mode=read-only
The following keys are allowed:
*source*::: Specifies the full path to an existing disk. Required. The disk must be in the same zone.
*mode*::: Specifies the mode in which to attach this disk. Valid options are 'read-write',
'read-only'. If not specified, the default is 'read-write'.
description:
arg_name: description
required: false
help_text: |
Text description of the TPU.
labels:
arg_name: labels
metavar: KEY=VALUE
type: "googlecloudsdk.calliope.arg_parsers:ArgDict:"
required: false
help_text: |
Resource labels to represent user-provided metadata. See
https://cloud.google.com/compute/docs/labeling-resources for details.
range: &range
arg_name: range
help_text: |
CIDR range for the TPU.
The IP range that the TPU will select an IP address from.
Must be in CIDR notation and a `/29` range, for example
`192.168.0.0/29`. Errors will occur if the CIDR range has already been
used for a currently existing TPU, the CIDR range conflicts with any
networks in the user's provided network, or the provided network is
peered with another network that is using that CIDR range.
enable_secure_boot:
arg_name: shielded-secure-boot
type: bool
default: false
required: false
help_text: |
Specifies that the TPU instances are created with secure boot enabled. This implicitly makes
them Shielded VM instances.
autocheckpoint_enabled:
arg_name: autocheckpoint-enabled
api_field: queuedResource.tpu.nodeSpec.node.autocheckpointEnabled
type: bool
default: false
required: false
help_text: |
Specifies that the TPU node(s) are created with the Autocheckpoint feature enabled.
boot_disk:
arg_name: boot-disk
required: false
metavar: KEY=VALUE
type: "googlecloudsdk.calliope.arg_parsers:ArgDict:"
help_text: |
Specifies the boot disk configuration.
$ {command} \
--boot-disk kms-key=<full_kms_key_name_here>
The following keys are allowed:
*kms-key*::: Specifies the fully qualified Cloud KMS cryptokey name
which will be used to protect the disk. KMS cryptokey name format:
projects/<kms-project>/locations/<kms-location>/keyRings/<kms-keyring>/cryptoKeys/<key-name>
ALPHA:
help_text: |
Specifies the boot disk configuration.
$ {command} \
--boot-disk confidential-compute=True,kms-key=<full_kms_key_name_here>
The following keys are allowed:
*confidential-compute*::: Create the boot disk in confidential compute mode.
CMEK layer is required.
*kms-key*::: Specifies the fully qualified Cloud KMS cryptokey name
which will be used to protect the disk. KMS cryptokey name format:
projects/<kms-project>/locations/<kms-location>/keyRings/<kms-keyring>/cryptoKeys/<key-name>
network_config:
api_field: queuedResource.tpu.nodeSpec.node.networkConfigs
arg_name: network-config
type:
arg_dict:
flatten: false
spec:
- api_field: network
arg_name: network
type: str
required: false
- api_field: subnetwork
arg_name: subnetwork
type: str
required: false
- api_field: enableExternalIps
arg_name: external-ips
required: false
type: bool
required: false
help_text: |
Specify a full network_config for the TPU. network or subnetwork must be specified.
This flag may be repeated to provide multiple networks. For example:
$ {command} \
--network-config network=example_network,external-ips=true
The following keys are allowed:
*network*::: Specify the network to be used by this TPU.
*subnetwork*::: Specify the subnetwork to be used by this TPU. If not specified,
this defaults to the subnetwork with the same name as the network.
*external-ips*::: Indicate that the IP addresses for the node should be external. The default will be
only internal IP addresses.
reservation:
api_field: queuedResource.reservationName
arg_name: reservation
type: str
required: false
help_text: |
The name of the reservation. This can either be the full name or just the name.
However, a full name is required if the reservation is not in the default project.
Full reservation name format is:
projects/<project>/locations/<location>/reservations/<reservation-name>,
projects/<project>/zones/<zone>/reservations/<reservation-name>.

View File

@@ -0,0 +1,32 @@
project:
name: project
collection: tpu.projects
attributes:
- &project
parameter_name: projectsId
attribute_name: project
help: The project ID.
location:
name: location
collection: tpu.projects.locations
disable_auto_completers: false
attributes:
- &location
parameter_name: locationsId
attribute_name: zone
help: |
The compute/zone of the Cloud TPU.
If not specified, will use `default` compute/zone.
property: compute/zone
queuedresource:
name: queued_resource
collection: tpu.projects.locations.queuedResources
attributes:
- *project
- *location
- parameter_name: queuedResourcesId
attribute_name: queued_resource
help: The unqualified resource name.

View File

@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SSH/SCP utilities for Cloud TPU Queued Resource commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.core import log
import six
def ParseNodeFlag(node_flag, node_specs):
  """Parses the --node flag into a list of node_specs.

  Args:
    node_flag: str, 'all' (case-insensitive) or a comma-separated list of
      node indexes and inclusive 'start-end' ranges, e.g. '0,2-4'.
    node_specs: list, one spec per node in the TPU Queued Resource.

  Returns:
    The node specs selected by node_flag, in ascending index order.

  Raises:
    exceptions.InvalidArgumentException: if node_flag cannot be parsed or
      selects an index outside [0, len(node_specs) - 1].
  """
  num_nodes = len(node_specs)
  if str(node_flag).upper() == 'ALL':
    indexes = set(range(num_nodes))
  else:
    indexes = set()
    for r in node_flag.split(','):
      if not r:
        continue
      if '-' in r:
        bounds = r.split('-')
        try:
          if len(bounds) != 2:
            raise ValueError(r)
          start, end = int(bounds[0]), int(bounds[1])
          if start >= end:
            raise ValueError(r)
        except ValueError:
          # Covers malformed shapes, empty or non-numeric bounds (previously
          # an uncaught ValueError for input like 'a-b'), and start >= end.
          raise exceptions.InvalidArgumentException(
              '--node',
              'Range "{}" does not match expected format'
              ' "lowerBound-upperBound", where lowerBound < upperBound.'
              .format(r),
          )
        indexes.update(range(start, end + 1))
      else:
        try:
          indexes.add(int(r))
        except ValueError:
          raise exceptions.InvalidArgumentException(
              '--node',
              'unable to parse node ID {}. Please only use numbers.'.format(r),
          )
  if not indexes:
    raise exceptions.InvalidArgumentException(
        '--node',
        'Unable to parse node ranges from {}.'.format(node_flag),
    )
  mx = max(indexes)
  if mx >= num_nodes:
    raise exceptions.InvalidArgumentException(
        '--node',
        'node index {} is larger than the valid node indices on this TPU Queued'
        ' Resource. Please only use indexes in the range [0, {}], inclusive.'
        .format(mx, num_nodes - 1),
    )
  # Sort so the returned specs have a deterministic, ascending order
  # regardless of set iteration order.
  return [node_specs[i] for i in sorted(indexes)]
def WaitForNodeBatchCompletion(ssh_threads, nodes):
  """Blocks until every SSH thread in the batch has finished.

  Failures are not surfaced here; each thread is simply joined, then a status
  line is printed for every node that was prepared.

  Args:
    ssh_threads: List of ssh threads.
    nodes: List of SSH prepped nodes.
  """
  for worker in ssh_threads:
    worker.join()
  for prepped_node in nodes:
    if not prepped_node:
      continue
    log.status.Print(
        'Finished preparing node {}.'.format(prepped_node.tpu_name))

View File

@@ -0,0 +1,444 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility file that contains helpers for Queued Resources."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import re
from googlecloudsdk.api_lib.compute import metadata_utils
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.calliope import base
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.core import exceptions as sdk_core_exceptions
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core.util import times
import six
class BootDiskConfigurationError(sdk_core_exceptions.Error):
  """Error raised when the --boot-disk flag configuration is invalid."""
def GetMessagesModule(version):
  """Returns the generated TPU API messages module for the given version."""
  return apis.GetMessagesModule('tpu', version)
# TODO(b/276933950) Consider merging this MergeMetadata with
# googlecloudsdk.command_lib.compute.tpus.tpu_vm.util.MergeMetadata by moving
# it to googlecloudsdk.command_lib.compute.tpus.util
def MergeMetadata(args, api_version):
  """Creates the metadata for the Node.

  Based on googlecloudsdk.command_lib.compute.tpus.tpu_vm.util.MergeMetadata.

  Args:
    args: The gcloud args
    api_version: The api version (e.g. v2 or v2alpha1)

  Returns:
    The constructed MetadataValue.
  """
  # Combine --metadata and --metadata-from-file into a single dict.
  metadata_dict = metadata_utils.ConstructMetadataDict(
      args.metadata, args.metadata_from_file
  )
  tpu_messages = GetMessagesModule(api_version)
  metadata = tpu_messages.Node.MetadataValue()
  # Plain dict iteration replaces six.iteritems: this codebase is
  # Python 3 only (it uses f-strings elsewhere in this module).
  for key, value in metadata_dict.items():
    metadata.additionalProperties.append(
        tpu_messages.Node.MetadataValue.AdditionalProperty(key=key, value=value)
    )
  return metadata
def CreateNodeSpec(api_version):
  """Creates the repeated structure nodeSpec from args.

  Returns a declarative request hook that populates
  request.queuedResource.tpu.nodeSpec from the parsed command-line flags.

  Args:
    api_version: str, the TPU API version ('v2' or 'v2alpha1'); selects the
      messages module and which multi-node message shape is used.

  Returns:
    A Process(ref, args, request) hook suitable for a declarative command.
  """

  def Process(ref, args, request):
    # Build the (single) NodeSpec for the queued resource request in place.
    tpu_messages = GetMessagesModule(api_version)
    if request.queuedResource is None:
      request.queuedResource = tpu_messages.QueuedResource()
    if request.queuedResource.tpu is None:
      request.queuedResource.tpu = tpu_messages.Tpu()
    # Reuse a nodeSpec a previous hook may already have attached; otherwise
    # start from a fresh NodeSpec.
    if request.queuedResource.tpu.nodeSpec:
      node_spec = request.queuedResource.tpu.nodeSpec[0]
    else:
      request.queuedResource.tpu.nodeSpec = []
      node_spec = tpu_messages.NodeSpec()
    # NOTE(review): node and parent are (re)assigned even when an existing
    # nodeSpec was found above — confirm the overwrite is intended.
    node_spec.node = tpu_messages.Node()
    node_spec.parent = ref.Parent().RelativeName()
    if args.accelerator_type:
      node_spec.node.acceleratorType = args.accelerator_type
    node_spec.node.runtimeVersion = args.runtime_version
    if args.data_disk:
      # Each --data-disk repetition becomes one AttachedDisk.
      node_spec.node.dataDisks = []
      for data_disk in args.data_disk:
        attached_disk = tpu_messages.AttachedDisk(
            sourceDisk=data_disk.sourceDisk, mode=data_disk.mode
        )
        node_spec.node.dataDisks.append(attached_disk)
    if args.description:
      node_spec.node.description = args.description
    if args.labels:
      node_spec.node.labels = tpu_messages.Node.LabelsValue()
      node_spec.node.labels.additionalProperties = []
      for key, value in args.labels.items():
        node_spec.node.labels.additionalProperties.append(
            tpu_messages.Node.LabelsValue.AdditionalProperty(
                key=key, value=value
            )
        )
    if args.range:
      node_spec.node.cidrBlock = args.range
    if args.shielded_secure_boot:
      # --shielded-secure-boot implicitly makes the instance a Shielded VM.
      node_spec.node.shieldedInstanceConfig = (
          tpu_messages.ShieldedInstanceConfig(enableSecureBoot=True)
      )
    # Autocheckpoint is only exposed on the v2alpha1 surface.
    if api_version == 'v2alpha1' and args.autocheckpoint_enabled:
      node_spec.node.autocheckpointEnabled = True
    node_spec.node.networkConfig = tpu_messages.NetworkConfig()
    node_spec.node.serviceAccount = tpu_messages.ServiceAccount()
    if args.network:
      node_spec.node.networkConfig.network = args.network
    if args.subnetwork:
      node_spec.node.networkConfig.subnetwork = args.subnetwork
    if args.service_account:
      node_spec.node.serviceAccount.email = args.service_account
    if args.scopes:
      node_spec.node.serviceAccount.scope = args.scopes
    if args.tags:
      node_spec.node.tags = args.tags
    # The user-facing --internal-ips flag is the inverse of the API field.
    node_spec.node.networkConfig.enableExternalIps = not args.internal_ips
    if args.boot_disk:
      node_spec.node.bootDiskConfig = ParseBootDiskConfig(
          args.boot_disk, api_version)
    node_spec.node.metadata = MergeMetadata(args, api_version)
    # --node-prefix only makes sense for multi-node requests.
    if args.node_prefix and not args.node_count:
      raise exceptions.ConflictingArgumentsException(
          'Can only specify --node-prefix if --node-count is also specified.'
      )
    if args.node_id:
      node_spec.nodeId = args.node_id
    elif args.node_count:
      # Multi-node request: the message shape differs between API versions.
      if api_version == 'v2alpha1':
        node_spec.multiNodeParams = tpu_messages.MultiNodeParams()
        node_spec.multiNodeParams.nodeCount = args.node_count
        if args.node_prefix:
          node_spec.multiNodeParams.nodeIdPrefix = args.node_prefix
        if args.workload_type == 'AVAILABILITY_OPTIMIZED':
          node_spec.multiNodeParams.workloadType = (
              tpu_messages.MultiNodeParams.WorkloadTypeValueValuesEnum.AVAILABILITY_OPTIMIZED
          )
      else:  # For v2 API, MultiNodeParams was renamed to MultisliceParams
        node_spec.multisliceParams = tpu_messages.MultisliceParams()
        node_spec.multisliceParams.nodeCount = args.node_count
        if args.node_prefix:
          node_spec.multisliceParams.nodeIdPrefix = args.node_prefix
    request.queuedResource.tpu.nodeSpec = [node_spec]
    return request

  return Process
def ParseBootDiskConfig(boot_disk_args, api_version='v2alpha1'):
  """Parses configurations for boot disk.

  The original return annotation called GetMessagesModule('v2alpha1') at
  definition time, which eagerly loads the v2alpha1 messages module on import
  and hard-codes v2alpha1 regardless of api_version; the return contract is
  documented here instead.

  Args:
    boot_disk_args: dict, parsed --boot-disk key/value pairs. Supported keys
      are 'kms-key' and (v2alpha1 only) 'confidential-compute'.
    api_version: str, the TPU API version (e.g. 'v2' or 'v2alpha1').

  Returns:
    A tpu_messages.BootDiskConfig, or None when no configuration was given.

  Raises:
    BootDiskConfigurationError: on unsupported keys, or when
      confidential-compute is requested without a kms-key.
    exceptions.InvalidArgumentException: when confidential-compute is used
      outside the v2alpha1 (alpha) release track.
  """
  tpu_messages = GetMessagesModule(api_version)
  kms_key_arg_name = 'kms-key'
  confidential_compute_arg_name = 'confidential-compute'
  for arg_name in boot_disk_args.keys():
    if arg_name not in [kms_key_arg_name, confidential_compute_arg_name]:
      raise BootDiskConfigurationError(
          '--boot-disk only supports arguments: {} and {}'.format(
              confidential_compute_arg_name, kms_key_arg_name
          )
      )
  # ArgDict values arrive as strings, so compare case-insensitively.
  enable_confidential_compute = (
      boot_disk_args.get(confidential_compute_arg_name, 'False').lower()
      == 'true'
  )
  kms_key = boot_disk_args.get(kms_key_arg_name, None)
  if enable_confidential_compute:
    if api_version != 'v2alpha1':
      raise exceptions.InvalidArgumentException(
          '--boot-disk',
          'confidential-compute is only available in the alpha release track.')
    # Confidential compute requires a customer-managed encryption key.
    if kms_key is None:
      raise BootDiskConfigurationError(
          'argument --boot-disk: with confidential-compute={} '
          'requires kms-key; received: {}'.format(
              enable_confidential_compute, kms_key)
      )
  boot_disk_config_kwargs = {}
  if kms_key:
    customer_encryption_key = tpu_messages.CustomerEncryptionKey(
        kmsKeyName=kms_key)
    boot_disk_config_kwargs['customerEncryptionKey'] = customer_encryption_key
  if api_version == 'v2alpha1' and enable_confidential_compute:
    boot_disk_config_kwargs['enableConfidentialCompute'] = (
        enable_confidential_compute)
  if boot_disk_config_kwargs:
    return tpu_messages.BootDiskConfig(**boot_disk_config_kwargs)
  # Nothing to configure: callers treat None as "no bootDiskConfig".
  return None
def VerifyNodeCount(ref, args, request):
  """Validates that --node-count, when set and truthy, is greater than 1."""
  del ref  # unused
  count = args.node_count
  if count and count <= 1:
    raise exceptions.InvalidArgumentException(
        '--node-count', 'Node count must be greater than 1'
    )
  return request
def SetBestEffort(ref, args, request):
  """Creates an empty BestEffort structure if best-effort arg flag is set."""
  del ref  # unused
  if not args.best_effort:
    return request
  tpu_messages = GetMessagesModule('v2alpha1')
  if request.queuedResource is None:
    request.queuedResource = tpu_messages.QueuedResource()
  if request.queuedResource.bestEffort is None:
    request.queuedResource.bestEffort = tpu_messages.BestEffort()
  return request
def SetSpot(api_version):
  """Creates an empty Spot structure if spot flag is set."""

  def Process(ref, args, request):
    """Attaches an empty Spot message when --spot (or SPOT model) is given."""
    del ref  # unused
    tpu_messages = GetMessagesModule(api_version)
    if request.queuedResource is None:
      request.queuedResource = tpu_messages.QueuedResource()
    wants_spot = bool(args.spot)
    # v2alpha1 also accepts the spot request via --provisioning-model.
    if not wants_spot and api_version == 'v2alpha1' and args.provisioning_model:
      normalized = args.provisioning_model.replace('-', '_').upper()
      spot_enum = (
          tpu_messages.QueuedResource.ProvisioningModelValueValuesEnum.SPOT)
      wants_spot = normalized == spot_enum
    if wants_spot and request.queuedResource.spot is None:
      request.queuedResource.spot = tpu_messages.Spot()
    return request

  return Process
def SetGuaranteed(api_version):
  """Creates an empty Guaranteed structure if arg flag is set."""

  def Process(ref, args, request):
    """Attaches an empty Guaranteed message when --guaranteed is given."""
    del ref  # unused
    if not args.guaranteed:
      return request
    tpu_messages = GetMessagesModule(api_version)
    if request.queuedResource is None:
      request.queuedResource = tpu_messages.QueuedResource()
    if request.queuedResource.guaranteed is None:
      request.queuedResource.guaranteed = tpu_messages.Guaranteed()
    return request

  return Process
def SetProvisioningModel(api_version):
  """Sets the provisioning model enum value.

  Args:
    api_version: The api version (e.g. v2 or v2alpha1).

  Returns:
    Request hook that populates request.queuedResource.provisioningModel
    from --provisioning-model, or from --spot when the former is absent.
  """

  def Process(ref, args, request):
    del ref  # unused
    tpu_messages = GetMessagesModule(api_version)
    if request.queuedResource is None:
      request.queuedResource = tpu_messages.QueuedResource()
    if not args.provisioning_model:
      # No explicit --provisioning-model: --spot is documented as equivalent
      # to --provisioning-model=spot, so honor it here. The 'else' is
      # essential: previously the STANDARD assignment ran unconditionally and
      # overwrote the SPOT value just assigned.
      if args.spot:
        request.queuedResource.provisioningModel = (
            tpu_messages.QueuedResource.ProvisioningModelValueValuesEnum.SPOT
        )
      else:
        request.queuedResource.provisioningModel = (
            tpu_messages.QueuedResource.ProvisioningModelValueValuesEnum.STANDARD
        )
      return request
    # create.yaml is in declarative style without direct access to the parser,
    # instead leverage the generated client's enum functionality.
    try:
      # Per gcloud style guidance, standard choice flag options are lower
      # case with dashes but we also support underscores and upper case.
      normalized_candidate = args.provisioning_model.replace('-', '_').upper()
      candidate_enum = (
          tpu_messages.QueuedResource.ProvisioningModelValueValuesEnum(
              normalized_candidate
          )
      )
    except TypeError as e:
      raise exceptions.InvalidArgumentException(
          '--provisioning-model',
          f'{args.provisioning_model} is not a valid provisioning model, must'
          ' be one of [standard, spot, reservation-bound, flex-start]',
      ) from e
    request.queuedResource.provisioningModel = candidate_enum
    return request

  return Process
def SetValidInterval(api_version):
  """Combine multiple timing constraints into a valid_interval.

  Args:
    api_version: The api version (e.g. v2 or v2alpha1).

  Returns:
    Request hook that sets request.queuedResource.queueingPolicy.validInterval.
  """

  def Process(ref, args, request):
    del ref  # unused
    # At most one of (absolute time, relative duration) per endpoint.
    if (args.valid_after_duration and args.valid_after_time) or (
        args.valid_until_duration and args.valid_until_time
    ):
      raise exceptions.ConflictingArgumentsException(
          'Only one timing constraint for each of (start, end) time is'
          ' permitted'
      )
    tpu_messages = GetMessagesModule(api_version)
    now = times.Now()

    def _Resolve(absolute, duration):
      # An absolute timestamp wins; otherwise derive one from the duration.
      if absolute:
        return absolute
      if duration:
        return duration.GetRelativeDateTime(now)
      return None

    start = _Resolve(args.valid_after_time, args.valid_after_duration)
    end = _Resolve(args.valid_until_time, args.valid_until_duration)
    if start and end:
      interval = tpu_messages.Interval(
          startTime=times.FormatDateTime(start),
          endTime=times.FormatDateTime(end),
      )
      if request.queuedResource is None:
        request.queuedResource = tpu_messages.QueuedResource()
      # Replace any other queueing policy; only the interval survives.
      request.queuedResource.queueingPolicy = tpu_messages.QueueingPolicy()
      request.queuedResource.queueingPolicy.validInterval = interval
    return request

  return Process
def CreateReservationName(api_version):
  """Creates the target reservation name from args.

  Args:
    api_version: The api version (e.g. v2 or v2alpha1)

  Returns:
    Handler which sets request.queuedResource.reservationName
  """

  def Process(ref, args, request):
    del ref  # unused
    hosts = [
        args.reservation_host_project,
        args.reservation_host_folder,
        args.reservation_host_organization,
    ]
    # Any two host flags together are ambiguous.
    if sum(1 for host in hosts if host) > 1:
      raise exceptions.ConflictingArgumentsException(
          'Only one reservation host is permitted'
      )
    wildcard_pattern = '{}/{}/locations/{}/reservations/-'
    short_name_regex = '^[a-zA-Z0-9-]+$'
    full_name_pattern = 'projects/{}/locations/{}/reservations/{}'

    reservation_name = None
    if args.reservation_host_project:
      reservation_name = wildcard_pattern.format(
          'projects', args.reservation_host_project, args.zone
      )
    elif args.reservation_host_folder:
      reservation_name = wildcard_pattern.format(
          'folders', args.reservation_host_folder, args.zone
      )
    elif args.reservation_host_organization:
      reservation_name = wildcard_pattern.format(
          'organizations', args.reservation_host_organization, args.zone
      )
    elif api_version == 'v2' and getattr(args, 'reserved', False):
      # --reserved targets any reservation in the current project.
      project = properties.VALUES.core.project.GetOrFail()
      reservation_name = wildcard_pattern.format('projects', project, args.zone)
    elif getattr(args, 'reservation', None) and re.match(
        short_name_regex, args.reservation
    ):
      # A short reservation ID is qualified with the current project and zone.
      project = properties.VALUES.core.project.GetOrFail()
      reservation_name = full_name_pattern.format(
          project, args.zone, args.reservation
      )
    if reservation_name:
      request.queuedResource.reservationName = reservation_name
    return request

  return Process
def SetForce(ref, args, request):
  """Sets request.force to true when the --force flag is present and set.

  Args:
    ref: ref to the service (unused).
    args: The args for this method.
    request: The request to be made.

  Returns:
    The request, with force enabled if requested.
  """
  del ref  # unused
  if getattr(args, 'force', False):
    request.force = True
  return request
class TPUQueuedResource(object):
  """Helper to get TPU Queued Resources."""

  def __init__(self, release_track):
    # ALPHA uses the v2alpha1 API surface; all other tracks use v2.
    self._api_version = (
        'v2alpha1' if release_track == base.ReleaseTrack.ALPHA else 'v2'
    )
    self.client = apis.GetClientInstance('tpu', self._api_version)
    self.messages = apis.GetMessagesModule('tpu', self._api_version)

  def GetMessages(self):
    """Returns the messages module for the selected API version."""
    return self.messages

  def Get(self, name, zone):
    """Retrieves the Queued Resource in the given project and zone."""
    project = properties.VALUES.core.project.Get(required=True)
    queued_resource_ref = resources.REGISTRY.Parse(
        name,
        params={'locationsId': zone, 'projectsId': project},
        collection='tpu.projects.locations.queuedResources',
        api_version=self._api_version,
    )
    get_request = self.messages.TpuProjectsLocationsQueuedResourcesGetRequest(
        name=queued_resource_ref.RelativeName()
    )
    return self.client.projects_locations_queuedResources.Get(get_request)

View File

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,13 @@
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,32 @@
project:
name: project
collection: tpu.projects
attributes:
- &project
parameter_name: projectsId
attribute_name: project
help: Project ID.
location:
name: location
collection: tpu.projects.locations
disable_auto_completers: false
attributes:
- &location
parameter_name: locationsId
attribute_name: zone
help: |
Zone of the Cloud TPU.
If not specified, will use `default` compute/zone.
property: compute/zone
reservation:
name: reservation
collection: tpu.projects.locations.reservations
attributes:
- *project
- *location
- parameter_name: reservationId
attribute_name: reservation
help: The unqualified reservation ID.

View File

@@ -0,0 +1,48 @@
project:
name: project
collection: tpu.projects
attributes:
- parameter_name: projectsId
attribute_name: project
help: The project ID.
location:
name: location
collection: tpu.projects.locations
disable_auto_completers: false
attributes:
- &location
parameter_name: locationsId
attribute_name: zone
help: |
The compute/zone of the Cloud TPU.
If not specified, will use `default` compute/zone.
property: compute/zone
tpu:
name: tpu
collection: tpu.projects.locations.nodes
attributes:
- *location
- parameter_name: nodesId
attribute_name: tpu
help: The identifier of the Cloud TPU
tensorflowversion:
name: tensorflow_version
collection: tpu.projects.locations.tensorflowVersions
attributes:
- *location
- parameter_name: tensorflowVersionsId
attribute_name: version
help: The id of the Tensorflow version.
acceleratortype:
name: accelerator_type
collection: tpu.projects.locations.acceleratorTypes
attributes:
- *location
- parameter_name: acceleratorTypesId
attribute_name: accelerator_type
help: The id of the accelerator type.

View File

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Exceptions for Cloud TPU VM libraries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.core import exceptions
class TPUInUnusableTerminalState(exceptions.Error):
  """Error when the TPU is in an unusable, terminal state."""

  def __init__(self, state):
    message = (
        'This TPU has terminal state "{}", so it cannot be used anymore.'
        .format(state))
    super(TPUInUnusableTerminalState, self).__init__(message)
class TPUInUnusableState(exceptions.Error):
  """Error when the TPU is in an unusable (but possibly transient) state."""

  def __init__(self, state):
    message = (
        'This TPU has state "{}", so it cannot be currently connected to.'
        .format(state))
    super(TPUInUnusableState, self).__init__(message)
class SSHKeyNotInAgent(exceptions.Error):
  """Error when the SSH key is not present in the SSH agent."""

  def __init__(self, identity_file):
    message = (
        'SSH Key is not present in the SSH agent. Please run "ssh-add {}" to '
        'add it, and try again.'.format(identity_file))
    super(SSHKeyNotInAgent, self).__init__(message)
class IapTunnelingUnavailable(exceptions.Error):
  """Error when IAP tunneling is unavailable (either temporarily or not)."""

  def __init__(self):
    message = 'Currently unable to connect to this TPU using IAP tunneling.'
    super(IapTunnelingUnavailable, self).__init__(message)
class TPUInMaintenanceEvent(exceptions.Error):
  """Error when the TPU is unavailable due to an ongoing maintenance event."""

  def __init__(self):
    message = 'This TPU is going through a maintenance event, and is currently unavailable. For more information, see https://cloud.google.com/tpu/docs/maintenance-events.'
    super(TPUInMaintenanceEvent, self).__init__(message)

View File

@@ -0,0 +1,409 @@
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
accelerator_type:
api_field: node.acceleratorType
arg_name: accelerator-type
required: false
help_text: |
TPU accelerator type for the TPU.
If not specified, this defaults to `v2-8`.
For a list of available accelerator types run:
`{parent_command} accelerator-types list`
type: googlecloudsdk.command_lib.util.hooks.types:LowerCaseType
default: 'v2-8'
description:
api_field: node.description
arg_name: description
required: false
help_text: |
Text description of the TPU.
network:
api_field: node.networkConfig.network
arg_name: network
required: false
help_text: |
Network that this TPU will be a part of.
default: 'default'
subnetwork:
api_field: node.networkConfig.subnetwork
arg_name: subnetwork
required: false
help_text: |
Subnetwork that this TPU will be a part of.
enable_external_ips:
api_field: node.networkConfig.enableExternalIps
arg_name: internal-ips
required: false
default: false
type: bool
processor: googlecloudsdk.command_lib.compute.tpus.tpu_vm.util:InvertBoolean
help_text: |
Indicate that the IP addresses for the node should be internal. The default is that external IP
addresses will be associated with the TPU workers.
queue_count:
api_field: node.networkConfig.queueCount
arg_name: queue-count
type: int
required: false
help_text: |
Specifies the networking queue count for TPU VM instances. Both Rx and Tx queues will be set to
this number. If it's not specified, a default queue count will be assigned. For Virtio-net,
each interface will get min(floor(#vCPU / #vNIC), 32) queues. For gVNIC, each interface will
get min(floor(#vCPU / #vNIC / 2), 16) queues.
version:
api_field: node.runtimeVersion
arg_name: version
required: true
help_text: |
Runtime version for the TPU, such as `2.3`.
For a list of available versions run:
`{parent_command} versions list`
range: &range
api_field: node.cidrBlock
arg_name: range
help_text: |
CIDR Range for the TPU.
The IP range that the TPU will select an IP address from.
Must be in CIDR notation and a `/29` range, for example
`192.168.0.0/29`. Errors will occur if the CIDR range has already been
used for a currently existing TPU, the CIDR range conflicts with any
networks in the user's provided network, or the provided network is
peered with another network that is using that CIDR range.
preemptible:
api_field: node.schedulingConfig.preemptible
arg_name: preemptible
type: bool
default: false
required: false
help_text: |
If provided, the TPU will be preemptible and time-limited. It may be
preempted to free up resources for standard TPUs, and will only be able
to run for a limited amount of time.
Preemptible TPUs cannot be restarted.
spot:
api_field: node.schedulingConfig.spot
arg_name: spot
type: bool
default: false
required: false
help_text: |
If specified, create this VM as a spot VM.
Spot VMs make unused capacity available at highly discounted rates.
Spot VMs may be preempted at any time if the capacity is needed, but unless preempted there is
no limit on runtime duration. Spot VM TPUs cannot be restarted, and must be recreated again.
reserved:
api_field: node.schedulingConfig.reserved
arg_name: reserved
type: bool
default: false
required: false
help_text: |
When specified, will attempt to create the TPU node under reservations made in the current
project. The reservations can be made separately but used in aggregated form. i.e., the user can
make a reservation of 128 V2 TPUs and later on make another reservation of 128 V2 TPUs then
creates a v2-256 TPU instance. If there exists no reservation or not sufficient amount of
reserved cores under the project, the request will fail due to lack of capacity.
service_account:
api_field: node.serviceAccount.email
arg_name: service-account
required: false
help_text: |
Email address of the service account. If empty, default Google Compute Engine service
account will be used.
service_account_scopes:
api_field: node.serviceAccount.scope
arg_name: scopes
type: "googlecloudsdk.calliope.arg_parsers:ArgList:"
required: false
help_text: |
List of comma-separated scopes to be made available for the service account.
labels:
api_field: node.labels
arg_name: labels
metavar: KEY=VALUE
type: "googlecloudsdk.calliope.arg_parsers:ArgDict:"
required: false
help_text: |
Resource labels to represent user-provided metadata. See
https://cloud.google.com/compute/docs/labeling-resources for details.
update_labels:
arg_name: update-labels
type: "googlecloudsdk.calliope.arg_parsers:ArgDict:"
required: false
metavar: KEY=VALUE
help_text: |
Resource labels to update that represent user-provided metadata. If a label exists, its value is
modified. Otherwise, a new label is created.
See https://cloud.google.com/compute/docs/labeling-resources for details.
clear_labels:
arg_name: clear-labels
type: bool
default: false
required: false
help_text: |
Remove all labels. If `--update-labels` is also specified then `--clear-labels` is applied
first.
For example, to remove all labels:
$ {command} --clear-labels
To remove all existing labels and create two new labels, 'foo' and 'baz':
$ {command} --clear-labels --update-labels foo=bar,baz=qux
remove_labels:
arg_name: remove-labels
type: "googlecloudsdk.calliope.arg_parsers:ArgList:"
required: false
metavar: KEY
help_text: |
List of label keys to remove. If a label does not exist it is silently ignored. If
`--update-labels` is also specified then `--update-labels` is applied first.
tags:
api_field: node.tags
arg_name: tags
type: "googlecloudsdk.calliope.arg_parsers:ArgList:"
required: false
help_text: |
Tags to apply to the TPU Node. Tags are used to identify valid sources or
targets for network firewalls. See https://cloud.google.com/vpc/docs/add-remove-network-tags for
more details.
add_tags:
arg_name: add-tags
type: "googlecloudsdk.calliope.arg_parsers:ArgList:"
required: false
metavar: TAGS
help_text: |
Tags to add to the TPU Node. Tags are used to identify valid sources or
targets for network firewalls. See https://cloud.google.com/vpc/docs/add-remove-network-tags for
more details.
clear_tags:
arg_name: clear-tags
type: bool
default: false
required: false
help_text: |
Remove all tags. If `--add-tags` is also specified then `--clear-tags` is applied
first.
For example, to remove all tags:
$ {command} --clear-tags
To remove all existing tags and create two new tags, 'foo' and 'bar':
$ {command} --clear-tags --add-tags foo,bar
remove_tags:
arg_name: remove-tags
type: "googlecloudsdk.calliope.arg_parsers:ArgList:"
required: false
metavar: TAG
help_text: |
List of tags to remove. If a tag does not exist it is silently ignored. If
`--add-tags` is also specified then `--add-tags` is applied first.
data_disks:
api_field: node.dataDisks
arg_name: data-disk
type:
arg_dict:
flatten: false
spec:
- api_field: sourceDisk
arg_name: source
type: str
required: true
- api_field: mode
arg_name: mode
type: str
required: false
choices:
- arg_value: read-write
enum_value: READ_WRITE
- arg_value: read-only
enum_value: READ_ONLY
required: false
help_text: |
Additional data disks for the TPU VM.
This flag must be repeated to provide multiple data disks. For example:
$ {command} --data-disk source=projects/my-project/zones/us-central1-c/disks/my-disk,mode=read-only
The following keys are allowed:
*source*::: Specifies the full path to an existing disk. Required. The disk must be in the same zone.
*mode*::: Specifies the mode in which to attach this disk. Valid options are 'read-write',
'read-only'. If not specified, the default is 'read-write'.
attach_disk:
arg_name: attach-disk
type: "googlecloudsdk.calliope.arg_parsers:ArgDict:"
required: false
metavar: SOURCE=DATA_DISK
help_text: |
Attach a data disk to the TPU VM. For example,
$ {command} example-tpu --attach-disk source=projects/my-project/zones/us-central1-c/disks/my-disk,mode=read-only --zone=us-central1-c
attaches the disk named 'projects/my-project/zones/us-central1-c/disks/my-disk' to a TPU VM named
'example-tpu' in read-only mode in zone `us-central1-c`.
The following keys are allowed:
*source*::: Specify the full path to an existing disk. Required. The disk must be in the same zone.
*mode*::: Specify the mode in which to attach this disk. Valid options are 'read-write',
'read-only'. If not specified, the default is 'read-write'.
detach_disk:
arg_name: detach-disk
type: str
required: false
metavar: DATA_DISK
help_text: |
    Detach a data disk from the TPU VM. For example,
$ {command} example-tpu --detach-disk=projects/my-project/zones/us-central1-c/disks/my-disk --zone=us-central1-c
detaches the disk named 'projects/my-project/zones/us-central1-c/disks/my-disk' from a TPU VM
named 'example-tpu' in zone `us-central1-c`.
enable_secure_boot:
api_field: node.shieldedInstanceConfig.enableSecureBoot
arg_name: shielded-secure-boot
type: bool
default: false
required: false
help_text: |
Specifies that the TPU instances are created with secure boot enabled. This implicitly makes
them Shielded VM instances.
boot_disk:
arg_name: boot-disk
required: false
metavar: KEY=VALUE
type: "googlecloudsdk.calliope.arg_parsers:ArgDict:"
help_text: |
Specifies the boot disk configuration.
$ {command} \
--boot-disk kms-key=<full_kms_key_name_here>
The following keys are allowed:
*kms-key*::: Specifies the fully qualified Cloud KMS cryptokey name
which will be used to protect the disk. KMS cryptokey name format:
projects/<kms-project>/locations/<kms-location>/keyRings/<kms-keyring>/cryptoKeys/<key-name>
ALPHA:
help_text: |
Specifies the boot disk configuration.
$ {command} \
--boot-disk confidential-compute=True,kms-key=<full_kms_key_name_here>
The following keys are allowed:
*confidential-compute*::: Create the boot disk in confidential compute mode.
CMEK layer is required.
*kms-key*::: Specifies the fully qualified Cloud KMS cryptokey name
which will be used to protect the disk. KMS cryptokey name format:
projects/<kms-project>/locations/<kms-location>/keyRings/<kms-keyring>/cryptoKeys/<key-name>
autocheckpoint_enabled:
api_field: node.autocheckpointEnabled
arg_name: autocheckpoint-enabled
type: bool
default: false
required: false
help_text: |
If specified, the TPU node is created with the Autocheckpoint feature enabled.
reservation:
api_field: node.schedulingConfig.reservationName
arg_name: reservation
type: str
required: false
help_text: |
The name of the reservation. This can either be the full name or just the name.
However, a full name is required if the reservation is not in the default project.
Full reservation name format is:
projects/<project>/reservations/<reservation-name>
provisioning_model:
arg_name: provisioning-model
required: false
help_text: |
The provisioning model of the TPU node. One of [standard, spot, reservation-bound]. Setting
`--provisioning-model=spot` is functionally equivalent to setting `--spot`.
worker:
arg_name: worker
type: "googlecloudsdk.calliope.arg_parsers:ArgList:"
required: false
help_text: |
List of worker IDs to apply attach/detach disk to.
If not specified, the update is applied to all workers.
Only numeric values are allowed.
The only exception is 'all'.
`--attach-disk` and `--worker` can be used to specify which workers to attach the disk to. For example:
$ {command} example-tpu --attach-disk source=projects/my-project/zones/us-central1-c/disks/my-disk,mode=read-only --zone=us-central1-c
--worker=0,1
attaches the disk named 'projects/my-project/zones/us-central1-c/disks/my-disk' to worker 0 and 1 of a TPU VM
named 'example-tpu' in read-only mode in zone `us-central1-c`.
if `--worker` is not specified or `--worker=all` is specified, the disk is attached to all workers.
`--detach-disk` and `--worker` can be used to specify which workers to detach the disk from. For example:
$ {command} example-tpu --detach-disk=projects/my-project/zones/us-central1-c/disks/my-disk --zone=us-central1-c
--worker=0,1
detaches the disk named 'projects/my-project/zones/us-central1-c/disks/my-disk' from worker 0 and 1 of a TPU VM
named 'example-tpu' in zone `us-central1-c`.
    if `--worker` is not specified or `--worker=all` is specified, the disk is detached from all workers.

View File

@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared resource flags for Cloud SDK attach and detach disk command.
resource_args adds the TPU resource argument to
the attach-disk and detach-disk command.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope.concepts import concepts
def TPUAttributeConfig():
  """Builds the resource attribute config for the TPU name argument."""
  config = concepts.ResourceParameterAttributeConfig(
      name='TPU', help_text='The TPU Name for the {resource}.')
  return config
def ZoneAttributeConfig():
  """Builds the resource attribute config for the Cloud zone argument."""
  config = concepts.ResourceParameterAttributeConfig(
      name='zone', help_text='The Cloud zone for the {resource}.')
  return config
def GetTPUResourceSpec(resource_name='TPU'):
  """Builds the TPU node resource spec for attach-disk/detach-disk arguments.

  Args:
    resource_name: Display name used for the resource in help text.

  Returns:
    A concepts.ResourceSpec for the tpu.projects.locations.nodes collection.
  """
  spec = concepts.ResourceSpec(
      'tpu.projects.locations.nodes',
      resource_name=resource_name,
      locationsId=ZoneAttributeConfig(),
      nodesId=TPUAttributeConfig(),
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG)
  return spec

View File

@@ -0,0 +1,48 @@
project:
name: project
collection: tpu.projects
attributes:
- parameter_name: projectsId
attribute_name: project
help: Project ID.
location:
name: location
collection: tpu.projects.locations
disable_auto_completers: false
attributes:
- &location
parameter_name: locationsId
attribute_name: zone
help: |
Zone of the Cloud TPU.
If not specified, will use `default` compute/zone.
property: compute/zone
tpu:
name: tpu
collection: tpu.projects.locations.nodes
attributes:
- *location
- parameter_name: nodesId
attribute_name: tpu
help: Identifier of the Cloud TPU.
runtimeversion:
name: runtime_version
collection: tpu.projects.locations.runtimeVersions
attributes:
- *location
- parameter_name: runtimeVersionsId
attribute_name: version
help: ID of the runtime version.
acceleratortype:
name: accelerator_type
collection: tpu.projects.locations.acceleratorTypes
attributes:
- *location
- parameter_name: acceleratorTypesId
attribute_name: accelerator_type
help: ID of the accelerator type.

View File

@@ -0,0 +1,791 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CLI Utilities for Cloud TPU VM commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import re
import sys
from googlecloudsdk.api_lib.compute import base_classes
from googlecloudsdk.api_lib.compute import metadata_utils
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.calliope import base
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.command_lib.util.args import map_util
from googlecloudsdk.core import exceptions as sdk_core_exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
import six
class NoFieldsSpecifiedError(sdk_core_exceptions.Error):
  """Error if no fields are specified for an update request."""


class AttachDiskError(sdk_core_exceptions.Error):
  """Error if the update request is invalid for attaching a disk."""


class DetachDiskError(sdk_core_exceptions.Error):
  """Error if the update request is invalid for detaching a disk."""


class BootDiskConfigurationError(sdk_core_exceptions.Error):
  """Error if the boot disk configuration is invalid."""


class WorkerIdsError(sdk_core_exceptions.Error):
  """Error if the worker IDs specified for the request are invalid."""
def GetProject(release_track, ssh_helper):
  """Fetches the current project resource via the compute API holder.

  Args:
    release_track: The calliope release track, used to pick the API client.
    ssh_helper: SSH helper that resolves the project with the compute client.

  Returns:
    The project resource returned by ssh_helper.GetProject.
  """
  holder = base_classes.ComputeApiHolder(release_track)
  return ssh_helper.GetProject(
      holder.client, properties.VALUES.core.project.GetOrFail())
def InvertBoolean(boolean):
  """Inverts the boolean value passed in."""
  return False if boolean else True
def MergeMetadata(api_version='v2'):
  """Request hook for merging the metadata and metadata from file."""

  def Process(unused_ref, args, request):
    """Request hook for merging the metadata and metadata from file.

    Args:
      unused_ref: ref to the service.
      args: The args for this method.
      request: The request to be made.

    Returns:
      Request with metadata field populated.
    """
    combined = metadata_utils.ConstructMetadataDict(
        args.metadata, args.metadata_from_file)
    tpu_messages = GetMessagesModule(version=api_version)
    if request.node.metadata is None:
      request.node.metadata = tpu_messages.Node.MetadataValue()
    entries = request.node.metadata.additionalProperties
    for key, value in six.iteritems(combined):
      entries.append(
          tpu_messages.Node.MetadataValue.AdditionalProperty(
              key=key, value=value))
    return request

  return Process
def GetTagsUpdateFromArgs(args, tags):
  """Generate the change to the tags on a resource based on the arguments.

  Applies --clear-tags first, then --add-tags, then --remove-tags.

  Args:
    args: The args for this method.
    tags: The current list of tags.

  Returns:
    The change to the tags after all of the arguments are applied.
  """
  updated = tags
  if args.IsKnownAndSpecified('clear_tags'):
    updated = []
  if args.IsKnownAndSpecified('add_tags'):
    updated = sorted(set(updated) | set(args.add_tags))
  if args.IsKnownAndSpecified('remove_tags'):
    updated = sorted(set(updated) - set(args.remove_tags))
  return updated
def GenerateUpdateMask(api_version='v2'):
  """Request hook for constructing the updateMask for update requests."""
  def Process(unused_ref, args, request):
    """Request hook for constructing the updateMask for update requests.
    Args:
      unused_ref: ref to the service.
      args: The args for this method.
      request: The request to be made.
    Returns:
      Request with updateMask field populated.
    Raises:
      NoFieldsSpecifiedError: if no fields were specified.
      AttachDiskError: if the request for attaching a disk is invalid.
      DetachDiskError: if the request for detaching a disk is invalid.
      WorkerIdsError: if --worker values are not numeric ids or "all".
    """
    update_mask = set()
    tpu_messages = GetMessagesModule(version=api_version)
    # Since it's possible that different API versions support different flags,
    # we must check that the flag is both known in this version and if it is
    # specified.
    if args.IsKnownAndSpecified('description'):
      update_mask.add('description')
    if args.IsKnownAndSpecified('internal_ips'):
      # NOTE(review): flag name is inverted relative to the masked field;
      # presumably the value is mapped via InvertBoolean elsewhere -- confirm.
      update_mask.add('network_config.enable_external_ips')
    if (args.IsKnownAndSpecified('update_labels') or
        args.IsKnownAndSpecified('remove_labels') or
        args.IsKnownAndSpecified('clear_labels')):
      labels_diff = labels_util.Diff.FromUpdateArgs(args)
      if labels_diff.MayHaveUpdates():
        labels_update = labels_diff.Apply(
            tpu_messages.Node.LabelsValue,
            request.node.labels)
        # Only send the field when the diff actually changes something.
        if labels_update.needs_update:
          request.node.labels = labels_update.labels
          update_mask.add('labels')
    if (args.IsKnownAndSpecified('add_tags') or
        args.IsKnownAndSpecified('remove_tags') or
        args.IsKnownAndSpecified('clear_tags')):
      tags_update = GetTagsUpdateFromArgs(args, request.node.tags)
      # Set comparison ignores ordering; skip a no-op tags update.
      if set(tags_update) != set(request.node.tags):
        request.node.tags = tags_update
        update_mask.add('tags')
    if args.IsKnownAndSpecified('metadata_from_file'):
      # --metadata-from-file fully replaces the node's existing metadata.
      metadata_dict = metadata_utils.ConstructMetadataDict(
          None, args.metadata_from_file)
      request.node.metadata = tpu_messages.Node.MetadataValue()
      for key, value in six.iteritems(metadata_dict):
        request.node.metadata.additionalProperties.append(
            tpu_messages.Node.MetadataValue.AdditionalProperty(
                key=key, value=value))
      update_mask.add('metadata')
    elif (args.IsKnownAndSpecified('update_metadata') or
          args.IsKnownAndSpecified('remove_metadata') or
          args.IsKnownAndSpecified('clear_metadata')):
      metadata_dict = {}
      if request.node.metadata is not None:
        for item in request.node.metadata.additionalProperties:
          metadata_dict[item.key] = item.value
      # Apply flags one by one since we allow multiple flags to be set at once.
      # The order should match the flags' descriptions.
      metadata_update = map_util.ApplyMapFlags(metadata_dict, None,
                                               None, args.clear_metadata, None,
                                               None)
      metadata_update = map_util.ApplyMapFlags(metadata_update, None,
                                               args.update_metadata, None, None,
                                               None)
      metadata_update = map_util.ApplyMapFlags(metadata_update, None, None,
                                               None, args.remove_metadata, None)
      request.node.metadata = tpu_messages.Node.MetadataValue()
      for key, value in six.iteritems(metadata_update):
        request.node.metadata.additionalProperties.append(
            tpu_messages.Node.MetadataValue.AdditionalProperty(
                key=key, value=value))
      update_mask.add('metadata')
    if args.IsKnownAndSpecified('attach_disk'):
      # validates worker
      if not args.IsKnownAndSpecified('worker'):
        args.worker = ['all']
      is_all_workers_specified = ValidateWorkerIdsField(args)
      if is_all_workers_specified:
        # An empty worker list denotes "all workers" downstream.
        args.worker = []
      mode, source = '', ''
      for key in args.attach_disk.keys():
        if key == 'mode':
          mode = args.attach_disk['mode']
        elif key == 'source':
          source = args.attach_disk['source']
        else:
          raise AttachDiskError(
              'argument --attach-disk: valid keys are [mode, source]; '
              'received: ' + key
          )
      if mode == 'read-only':
        mode_enum = tpu_messages.AttachedDisk.ModeValueValuesEnum.READ_ONLY
      elif not mode or mode == 'read-write':
        # read-write is the default when mode is omitted; a read-write disk
        # may be attached to at most one worker.
        mode_enum = tpu_messages.AttachedDisk.ModeValueValuesEnum.READ_WRITE
        if len(args.worker) > 1:
          raise AttachDiskError(
              'argument --attach-disk: can only attach disks in read-write'
              ' to at most one worker; received: ' + str(args.worker)
          )
      else:
        raise AttachDiskError(
            'argument --attach-disk: key mode: can only attach disks in '
            'read-write or read-only mode; received: ' + mode
        )
      # worker is de-duped and sorted.
      worker = set(args.worker)
      disk_to_attach = tpu_messages.AttachedDisk(
          mode=mode_enum,
          sourceDisk=source,
      )
      if api_version == 'v2alpha1':
        # Per-worker disk attachment is only modeled in the alpha API.
        disk_to_attach.workerIds = sorted(worker)
      PreprocessDiskToAttach(request.node.dataDisks, disk_to_attach)
      request.node.dataDisks.append(disk_to_attach)
      update_mask.add('data_disks')
    elif args.IsKnownAndSpecified('detach_disk'):
      # validates worker
      if not args.IsKnownAndSpecified('worker'):
        args.worker = ['all']
      is_all_workers_specified = ValidateWorkerIdsField(args)
      if is_all_workers_specified:
        args.worker = []
      if not request.node.dataDisks:
        raise DetachDiskError(
            'argument --detach-disk: No data disks to detach from current TPU '
            'VM.'
        )
      source_disk_list = []
      for disk in request.node.dataDisks:
        source_disk_list.append(disk.sourceDisk)
      for i, source_disk in enumerate(source_disk_list):
        if args.detach_disk != source_disk:
          continue
        if is_all_workers_specified:
          del request.node.dataDisks[i]
          break
        # Detach only the requested workers; drop the disk entirely once no
        # worker remains attached to it.
        worker_diff = set(
            request.node.dataDisks[i].workerIds) - set(args.worker)
        if not worker_diff:
          del request.node.dataDisks[i]
          break
        request.node.dataDisks[i].workerIds = sorted(worker_diff)
        break
      else:
        # for/else: only reached when no attached disk matched --detach-disk.
        raise DetachDiskError(
            'argument --detach-disk: The specified data disk '
            + args.detach_disk + ' is not currently attached to the TPU VM.'
        )
      update_mask.add('data_disks')
    if not update_mask:
      raise NoFieldsSpecifiedError(
          'No fields would change as a result of this update; must specify at '
          'least one field to update.')
    # NOTE(review): update_mask is a set, so field order in the serialized
    # mask is unspecified; assumed order-insensitive on the server -- confirm.
    request.updateMask = ','.join(update_mask)
    return request
  return Process
def RemoveConflictingDefaults(unused_ref, args, request):
  """Unset acceleratorType flag when it conflicts with topology arguments.

  Args:
    unused_ref: ref to the service.
    args: The args for this method.
    request: The request to be made.

  Returns:
    Request with metadata field populated.
  """
  if args.topology is None:
    return request
  # An explicit topology supersedes any (possibly defaulted) accelerator type.
  request.node.acceleratorType = None
  return request
def GetMessagesModule(version='v2'):
  """Returns the generated TPU API messages module for the given version."""
  return apis.GetMessagesModule('tpu', version)
def StartRequestHook(api_version='v2'):
  """Declarative request hook for TPU Start command."""
  def Process(ref, args, request):
    """Attaches an empty StartNodeRequest body to the request."""
    del ref
    del args
    messages = GetMessagesModule(version=api_version)
    request.startNodeRequest = messages.StartNodeRequest()
    return request
  return Process
def StopRequestHook(api_version='v2'):
  """Declarative request hook for TPU Stop command."""
  def Process(ref, args, request):
    """Attaches an empty StopNodeRequest body to the request."""
    del ref
    del args
    messages = GetMessagesModule(version=api_version)
    request.stopNodeRequest = messages.StopNodeRequest()
    return request
  return Process
def IsTPUVMNode(node):
  """Returns True iff the node reports a TPU VM (non-V1) API version."""
  version = six.text_type(node.apiVersion).upper()
  if version.startswith('V1') or version == 'API_VERSION_UNSPECIFIED':
    return False
  return True
def FilterTPUVMNodes(response, args):
  """Removes Cloud TPU V1 API nodes from the 'list' output.

  Used with 'compute tpus tpu-vm list'.

  Args:
    response: response to ListNodes.
    args: the arguments for the list command.

  Returns:
    A response with only TPU VM (non-V1 API) nodes.
  """
  del args
  return [node for node in response if IsTPUVMNode(node)]
class GuestAttributesListEntry(object):
  """One flattened row of GetGuestAttributes output.

  Attributes:
    worker_id: index of the worker the attribute belongs to.
    namespace: the guest attribute namespace.
    key: the guest attribute key.
    value: the guest attribute value.
  """

  def __init__(self, worker_id, namespace, key, value):
    self.value = value
    self.key = key
    self.namespace = namespace
    self.worker_id = worker_id
def TransformGuestAttributes(response, args):
  """Transforms the GuestAttributes into a flatter list.

  This is needed to make clearer output in the case of TPU pods, since they
  have many workers.

  Args:
    response: response to GetGuestAttributes.
    args: the arguments for the GetGuestAttributes command.

  Returns:
    A list of GuestAttributesListEntry objects.
  """
  del args
  # The worker's position in the response is used as its id.
  return [
      GuestAttributesListEntry(worker_id, entry.namespace, entry.key,
                               entry.value)
      for worker_id, attrs in enumerate(response.guestAttributes)
      for entry in attrs.queryValue.items
  ]
def PreprocessDiskToAttach(current_data_disks_list, disk_to_attach):
  """Preprocesses and validates the disk to attach.
  Validates the disk to attach is not already attached to the TPU VM with
  different mode or same mode and worker.
  Deletes the disk from the current_data_disks_list if it is already attached
  to the TPU VM with same mode but different worker.
  If the disk is currently attached to the TPU VM with same mode,
  joins the current worker list and the new worker list.
  Args:
    current_data_disks_list: the list of data disks currently attached to the
      TPU VM. May be mutated: a matching entry is removed in place.
    disk_to_attach: the disk to attach to the TPU VM. May be mutated: its
      workerIds can be merged with the already-attached disk's workers.
  Raises:
    AttachDiskError: if the disk is already attached to the TPU VM
      with different mode.
    AttachDiskError: if the disk is already attached to the TPU VM with same
      mode and worker.
  """
  for i, disk in enumerate(current_data_disks_list):
    # Only the entry with the same source disk is relevant.
    if disk.sourceDisk != disk_to_attach.sourceDisk:
      continue
    if (disk.mode != disk_to_attach.mode):
      raise AttachDiskError(
          'argument --attach-disk: the disk is already attached to the TPU '
          'VM with different mode.'
      )
    # Error only when the new worker set adds nothing beyond the workers
    # already attached.
    if not (set(disk_to_attach.workerIds) - set(disk.workerIds)):
      raise AttachDiskError(
          'argument --attach-disk: the disk is already attached to '
          'the same set of workers of TPU VM.'
      )
    disk_to_attach.workerIds = sorted(
        set(disk.workerIds + disk_to_attach.workerIds))
    # To avoid disk with same name appear twice in the list.
    # NOTE(review): deleting during enumerate skips the element after i;
    # safe only if sourceDisk values are unique in the list -- TODO confirm.
    del current_data_disks_list[i]
def ValidateWorkerIdsField(args):
  """Checks that the workers are numeric strings only.

  The only exception is "all", a special value meaning all workers; "all"
  must then be the sole entry.

  Args:
    args: the arguments for the update command.

  Returns:
    True if args.worker is exactly the single string "all", False otherwise.

  Raises:
    WorkerIdsError: if the workers are not numeric strings only, or "all"
      is combined with other workers.
  """
  workers = args.worker
  if workers == ['all']:
    return True
  for worker_id in workers:
    # Past the early return above, any 'all' here is mixed with other ids.
    if worker_id == 'all':
      raise WorkerIdsError(
          'argument --worker',
          '"all" cannot be specified with other worker.',
      )
    if not worker_id.isnumeric():
      raise WorkerIdsError(
          'argument --worker',
          'worker must be numeric strings only or '
          '"all". e.g. --worker=0,1,2 or --worker=all',
      )
  return False
def CheckTPUVMNode(response, args):
  """Verifies that the node is a TPU VM node.

  If it is not a TPU VM node, exit with an error instead.

  Args:
    response: response to GetNode.
    args: the arguments for the list command.

  Returns:
    The response to GetNode if the node is TPU VM.
  """
  del args
  if not IsTPUVMNode(response):
    log.err.Print('ERROR: Please use "gcloud compute tpus describe" for Cloud TPU'
                  ' nodes that are not TPU VM.')
    sys.exit(1)
  return response
def ParseBootDiskConfigurations(api_version='v2'):
  """Request hook for parsing boot disk configurations."""
  def Process(unused_ref, args, request):
    """Validates --boot-disk keys and populates node.bootDiskConfig.

    Args:
      unused_ref: ref to the service.
      args: The args for this method.
      request: The request to be made.

    Returns:
      Request with bootDiskConfig populated when relevant flags are set.

    Raises:
      BootDiskConfigurationError: on unsupported keys or a missing kms-key
        when confidential compute is requested.
      exceptions.InvalidArgumentException: if confidential-compute is used
        outside the alpha API.
    """
    if not args or not args.IsKnownAndSpecified('boot_disk'):
      return request
    kms_key_arg_name = 'kms-key'
    confidential_compute_arg_name = 'confidential-compute'
    allowed_keys = (kms_key_arg_name, confidential_compute_arg_name)
    for arg_name in args.boot_disk.keys():
      if arg_name in allowed_keys:
        continue
      raise BootDiskConfigurationError(
          '--boot-disk only supports arguments: {} and {}'.format(
              confidential_compute_arg_name, kms_key_arg_name
          )
      )
    tpu_messages = GetMessagesModule(version=api_version)
    # The flag value arrives as a string; treat anything but 'true'
    # (case-insensitive) as disabled.
    enable_confidential_compute = (
        args.boot_disk.get(confidential_compute_arg_name, 'False').lower()
        == 'true'
    )
    kms_key = args.boot_disk.get(kms_key_arg_name, None)
    if enable_confidential_compute:
      if api_version != 'v2alpha1':
        raise exceptions.InvalidArgumentException(
            '--boot-disk',
            'confidential-compute is only available in the alpha release track.'
        )
      if kms_key is None:
        raise BootDiskConfigurationError(
            'argument --boot-disk: with confidential-compute={} '
            'requires kms-key; received: {}'.format(
                enable_confidential_compute, kms_key)
        )
    config_kwargs = {}
    if kms_key:
      config_kwargs['customerEncryptionKey'] = (
          tpu_messages.CustomerEncryptionKey(kmsKeyName=kms_key))
    if api_version == 'v2alpha1' and enable_confidential_compute:
      config_kwargs['enableConfidentialCompute'] = enable_confidential_compute
    if config_kwargs:
      request.node.bootDiskConfig = tpu_messages.BootDiskConfig(
          **config_kwargs)
    return request
  return Process
def SetImage(api_version='v2alpha1'):
  """Request hook for setting the source machine image."""
  def Process(unused_ref, args, request):
    """Sets the source machine image in the request if provided."""
    if not args.IsSpecified('image'):
      return request
    tpu_messages = GetMessagesModule(version=api_version)
    if not request.node.bootDiskConfig:
      request.node.bootDiskConfig = tpu_messages.BootDiskConfig()
    request.node.bootDiskConfig.sourceImage = args.image
    return request
  return Process
def ProjectIdToProjectNumber(project_id):
  """Returns the Cloud project number associated with the `project_id`."""
  # Resolve the number through the Cloud Resource Manager v1 API.
  messages = apis.GetMessagesModule('cloudresourcemanager', 'v1')
  client = apis.GetClientInstance('cloudresourcemanager', 'v1')
  get_request = messages.CloudresourcemanagerProjectsGetRequest(
      projectId=project_id)
  return client.projects.Get(get_request).projectNumber
def CreateReservationName(unused_ref, args, request):
  """Request hook for creating the target reservation name.

  Args:
    unused_ref: ref to the service.
    args: The args for this method.
    request: The request to be made.

  Returns:
    Request with reservationName field populated.
  """
  if not args.IsKnownAndSpecified('reservation'):
    return request
  # Only short names are expanded; anything else (e.g. an already fully
  # qualified reservation path) is passed through untouched.
  if not re.match('^[a-zA-Z0-9-]+$', args.reservation):
    return request
  project_id = properties.VALUES.core.project.GetOrFail()
  project_number = ProjectIdToProjectNumber(project_id)
  request.node.schedulingConfig.reservationName = (
      'projects/{}/locations/{}/reservations/{}'.format(
          project_number, args.zone, args.reservation))
  return request
def SetProvisioningModel(api_version):
  """Returns a request hook that sets schedulingConfig.provisioningModel."""
  def Process(_, args, request):
    """Maps the --spot / --provisioning-model flags onto the request enum."""
    provisioning_enum = GetMessagesModule(
        api_version).SchedulingConfig.ProvisioningModelValueValuesEnum
    scheduling_config = request.node.schedulingConfig
    # --spot takes precedence over --provisioning-model.
    if args.spot:
      scheduling_config.provisioningModel = provisioning_enum.SPOT
      return request
    if not args.provisioning_model:
      scheduling_config.provisioningModel = provisioning_enum.STANDARD
      return request
    try:
      # e.g. 'reservation-bound' -> 'RESERVATION_BOUND'.
      candidate_enum = provisioning_enum(
          args.provisioning_model.replace('-', '_').upper())
    except TypeError as e:
      raise exceptions.InvalidArgumentException(
          '--provisioning-model',
          f'{args.provisioning_model} is not a valid provisioning model, must'
          ' be one of [standard, spot, reservation-bound]',
      ) from e
    scheduling_config.provisioningModel = candidate_enum
    return request
  return Process
class TPUNode(object):
  """Helper to create and modify TPU nodes."""
  def __init__(self, release_track):
    # Only the ALPHA track uses the alpha API surface; all other tracks
    # fall back to v2.
    if release_track == base.ReleaseTrack.ALPHA:
      self._api_version = 'v2alpha1'
    else:
      self._api_version = 'v2'
    self.client = apis.GetClientInstance('tpu', self._api_version)
    self.messages = apis.GetMessagesModule('tpu', self._api_version)
  def GetMessages(self):
    """Returns the TPU messages module for this node's API version."""
    return self.messages
  def Get(self, name, zone):
    """Retrieves the TPU node in the given zone."""
    project = properties.VALUES.core.project.Get(required=True)
    fully_qualified_node_name_ref = resources.REGISTRY.Parse(
        name,
        params={
            'locationsId': zone,
            'projectsId': project
        },
        collection='tpu.projects.locations.nodes',
    )
    request = self.messages.TpuProjectsLocationsNodesGetRequest(
        name=fully_qualified_node_name_ref.RelativeName())
    return self.client.projects_locations_nodes.Get(request)
  def GetGuestAttributes(self, name, zone, worker_id=''):
    """Retrieves the Guest Attributes for the nodes."""
    project = properties.VALUES.core.project.Get(required=True)
    fully_qualified_node_name_ref = resources.REGISTRY.Parse(
        name,
        params={
            'locationsId': zone,
            'projectsId': project
        },
        collection='tpu.projects.locations.nodes',
    )
    get_guest_attributes_request = self.messages.GetGuestAttributesRequest(
        workerIds=[worker_id])
    request = self.messages.TpuProjectsLocationsNodesGetGuestAttributesRequest(
        name=fully_qualified_node_name_ref.RelativeName(),
        getGuestAttributesRequest=get_guest_attributes_request)
    return self.client.projects_locations_nodes.GetGuestAttributes(request)
  def UpdateNode(self, name, zone, node, update_mask, poller_message):
    """Updates the TPU node in the given zone and waits for completion."""
    project = properties.VALUES.core.project.Get(required=True)
    fully_qualified_node_name_ref = resources.REGISTRY.Parse(
        name,
        params={
            'locationsId': zone,
            'projectsId': project
        },
        collection='tpu.projects.locations.nodes',
    )
    request = self.messages.TpuProjectsLocationsNodesPatchRequest(
        name=fully_qualified_node_name_ref.RelativeName(),
        node=node,
        updateMask=update_mask)
    # Call UpdateNode to start the LRO.
    operation = self.client.projects_locations_nodes.Patch(request)
    operation_ref = resources.REGISTRY.ParseRelativeName(
        operation.name, collection='tpu.projects.locations.operations'
    )
    # Wait for the UpdateNode LRO to complete.
    return self.WaitForOperation(operation_ref, poller_message)
  def UpdateMetadataKey(self, metadata, key, value):
    """Updates a key in the TPU metadata object.
    If the key does not exist, it is added.
    Args:
      metadata: tpu.messages.Node.MetadataValue, the TPU's metadata.
      key: str, the key to be updated.
      value: str, the new value for the key.
    Returns:
      The updated metadata.
    """
    # If the metadata is empty, return a new metadata object with just the key.
    if metadata is None or metadata.additionalProperties is None:
      return self.messages.Node.MetadataValue(
          additionalProperties=[
              self.messages.Node.MetadataValue.AdditionalProperty(
                  key=key, value=value)])
    item = None
    for x in metadata.additionalProperties:
      if x.key == key:
        item = x
        break
    if item is not None:
      item.value = value
    else:
      # The key is not in the metadata, so append it.
      metadata.additionalProperties.append(
          self.messages.Node.MetadataValue.AdditionalProperty(
              key=key, value=value))
    return metadata
  def WaitForOperation(self, operation_ref, message):
    """Polls the given LRO until completion, displaying `message`."""
    operation_poller = waiter.CloudOperationPoller(
        self.client.projects_locations_nodes,
        self.client.projects_locations_operations)
    return waiter.WaitFor(operation_poller, operation_ref, message)
class SSHPreppedNode(object):
  """Object that has all the data needed to successfully SSH into a node.

  Attributes:
    worker_ips: The IPs of the workers of the node.
    ssh_helper: The ssh_helper used to SSH into the node.
    id: The id of the node.
    tpu_name: The unqualified TPU VM name.
    instance_names: The name of the instances of the workers of the node.
    project: The project associated with the node.
    command_list: The list of the commands passed into ssh.
    remainder: The remainder list of ssh_args used to pass into the SSH command.
    host_key_suffixes: The host key suffixes associated with the node.
    user: The user executing the SSH command.
    release_track: The release track for the SSH protos (Alpha, Beta, etc.).
    enable_batching: A bool indicating if the user enabled batching for the
      node.
  """

  def __init__(self, tpu_name, user, release_track, enable_batching):
    # Caller-supplied identity and behavior flags.
    self.tpu_name = tpu_name
    self.user = user
    self.release_track = release_track
    self.enable_batching = enable_batching
    # The remaining attributes are filled in during SSH preparation.
    self.ssh_helper = None
    self.id = None
    self.project = None
    self.remainder = None
    self.worker_ips = []
    self.instance_names = []
    self.command_list = []
    self.host_key_suffixes = []
class SCPPreppedNode(SSHPreppedNode):
  """Object that has all the data needed to successfully SCP into a node.

  Attributes:
    srcs: The sources for SCP.
    dst: The destination for SCP.
  """

  def __init__(self, tpu_name, user, release_track, enable_batching, srcs, dst):
    # Everything except the copy endpoints is handled by the SSH base class.
    super(SCPPreppedNode, self).__init__(
        tpu_name, user, release_track, enable_batching)
    self.dst = dst
    self.srcs = srcs

View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CLI Utilities for cloud tpu commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
def ListTopologiesResponseHook(response, args):
  """Reformat to extract topologies and sort by acceleratorType."""
  del args
  rows = [
      {
          'topology': config.topology,
          'type': config.type,
          'acceleratorType': accelerator_type.type,
      }
      for accelerator_type in response
      for config in accelerator_type.acceleratorConfigs
  ]
  # Sort by the numeric part of e.g. 'v2-8' -> 8.
  rows.sort(key=lambda row: int(row['acceleratorType'].split('-')[1]))
  return rows