471 lines
18 KiB
Python
471 lines
18 KiB
Python
# -*- coding: utf-8 -*- #
|
|
# Copyright 2020 Google LLC. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""notebooks instances api helper."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import unicode_literals
|
|
|
|
import enum
|
|
from googlecloudsdk.api_lib.notebooks import environments as env_util
|
|
from googlecloudsdk.api_lib.notebooks import util
|
|
from googlecloudsdk.command_lib.util.apis import arg_utils
|
|
from googlecloudsdk.core import log
|
|
from googlecloudsdk.core import resources
|
|
|
|
_RESERVATION_AFFINITY_KEY = 'compute.googleapis.com/reservation-name'
|
|
|
|
|
|
def CreateInstance(args, client, messages):
|
|
"""Creates the Instance message for the create request.
|
|
|
|
Args:
|
|
args: Argparse object from Command.Run
|
|
client(base_api.BaseApiClient): An instance of the specified API client.
|
|
messages: Module containing messages definition for the specified API.
|
|
|
|
Returns:
|
|
Instance of the Instance message.
|
|
"""
|
|
|
|
def GetContainerImageFromExistingEnvironment():
|
|
environment_service = client.projects_locations_environments
|
|
result = environment_service.Get(
|
|
env_util.CreateEnvironmentDescribeRequest(args, messages))
|
|
return result.containerImage
|
|
|
|
def GetVmImageFromExistingEnvironment():
|
|
environment_service = client.projects_locations_environments
|
|
result = environment_service.Get(
|
|
env_util.CreateEnvironmentDescribeRequest(args, messages))
|
|
return result.vmImage
|
|
|
|
def GetKmsRelativeName():
|
|
if args.IsSpecified('kms_key'):
|
|
return args.CONCEPTS.kms_key.Parse().RelativeName()
|
|
|
|
def GetNetworkRelativeName():
|
|
if args.IsSpecified('network'):
|
|
return args.CONCEPTS.network.Parse().RelativeName()
|
|
|
|
def GetSubnetRelativeName():
|
|
if args.IsSpecified('subnet'):
|
|
return args.CONCEPTS.subnet.Parse().RelativeName()
|
|
|
|
def CreateAcceleratorConfigMessage():
|
|
accelerator_config = messages.AcceleratorConfig
|
|
type_enum = None
|
|
if args.IsSpecified('accelerator_type'):
|
|
type_enum = arg_utils.ChoiceEnumMapper(
|
|
arg_name='accelerator-type',
|
|
message_enum=accelerator_config.TypeValueValuesEnum,
|
|
include_filter=lambda x: 'UNSPECIFIED' not in x).GetEnumForChoice(
|
|
arg_utils.EnumNameToChoice(args.accelerator_type))
|
|
return accelerator_config(
|
|
type=type_enum, coreCount=args.accelerator_core_count)
|
|
|
|
def GetBootDisk():
|
|
type_enum = None
|
|
if args.IsSpecified('boot_disk_type'):
|
|
instance_message = messages.Instance
|
|
type_enum = arg_utils.ChoiceEnumMapper(
|
|
arg_name='boot-disk-type',
|
|
message_enum=instance_message.BootDiskTypeValueValuesEnum,
|
|
include_filter=lambda x: 'UNSPECIFIED' not in x).GetEnumForChoice(
|
|
arg_utils.EnumNameToChoice(args.boot_disk_type))
|
|
return type_enum
|
|
|
|
def GetDataDisk():
|
|
type_enum = None
|
|
if args.IsSpecified('data_disk_type'):
|
|
instance_message = messages.Instance
|
|
type_enum = arg_utils.ChoiceEnumMapper(
|
|
arg_name='data-disk-type',
|
|
message_enum=instance_message.DataDiskTypeValueValuesEnum,
|
|
include_filter=lambda x: 'UNSPECIFIED' not in x,
|
|
).GetEnumForChoice(arg_utils.EnumNameToChoice(args.data_disk_type))
|
|
return type_enum
|
|
|
|
def GetDiskEncryption():
|
|
type_enum = None
|
|
if args.IsSpecified('disk_encryption'):
|
|
instance_message = messages.Instance
|
|
type_enum = arg_utils.ChoiceEnumMapper(
|
|
arg_name='disk-encryption',
|
|
message_enum=instance_message.DiskEncryptionValueValuesEnum,
|
|
include_filter=lambda x: 'UNSPECIFIED' not in x).GetEnumForChoice(
|
|
arg_utils.EnumNameToChoice(args.disk_encryption))
|
|
return type_enum
|
|
|
|
def CreateContainerImageFromArgs():
|
|
if args.IsSpecified('environment'):
|
|
return GetContainerImageFromExistingEnvironment()
|
|
if args.IsSpecified('container_repository'):
|
|
container_image = messages.ContainerImage(
|
|
repository=args.container_repository, tag=args.container_tag)
|
|
return container_image
|
|
return None
|
|
|
|
def CreateVmImageFromArgs():
|
|
"""Create VmImage Message from an environment or from args."""
|
|
if args.IsSpecified('environment'):
|
|
return GetVmImageFromExistingEnvironment()
|
|
if args.IsSpecified('container_repository'):
|
|
return None
|
|
vm_image = messages.VmImage(project=args.vm_image_project)
|
|
if args.IsSpecified('vm_image_name'):
|
|
vm_image.imageName = args.vm_image_name
|
|
else:
|
|
vm_image.imageFamily = args.vm_image_family
|
|
return vm_image
|
|
|
|
def GetInstanceOwnersFromArgs():
|
|
if args.IsSpecified('instance_owners'):
|
|
return [args.instance_owners]
|
|
return []
|
|
|
|
def GetLabelsFromArgs():
|
|
if args.IsSpecified('labels'):
|
|
labels_message = messages.Instance.LabelsValue
|
|
return labels_message(additionalProperties=[
|
|
labels_message.AdditionalProperty(key=key, value=value)
|
|
for key, value in args.labels.items()
|
|
])
|
|
return None
|
|
|
|
def GetMetadataFromArgs():
|
|
if args.IsSpecified('metadata'):
|
|
metadata_message = messages.Instance.MetadataValue
|
|
return metadata_message(additionalProperties=[
|
|
metadata_message.AdditionalProperty(key=key, value=value)
|
|
for key, value in args.metadata.items()
|
|
])
|
|
return None
|
|
|
|
def GetShieldedInstanceConfigFromArgs():
|
|
if not (args.IsSpecified('shielded_vm_secure_boot') or
|
|
args.IsSpecified('shielded_vm_vtpm') or
|
|
args.IsSpecified('shielded_vm_integrity_monitoring')):
|
|
return None
|
|
shielded_instance_config_message = messages.ShieldedInstanceConfig
|
|
return shielded_instance_config_message(
|
|
enableIntegrityMonitoring=args.shielded_vm_integrity_monitoring,
|
|
enableSecureBoot=args.shielded_vm_secure_boot,
|
|
enableVtpm=args.shielded_vm_vtpm,
|
|
)
|
|
|
|
def GetTagsFromArgs():
|
|
if args.IsSpecified('tags'):
|
|
return args.tags
|
|
return []
|
|
|
|
# TODO(b/194714138): Consider adding validation for specific reservation.
|
|
def GetReservationAffinityConfigFromArgs():
|
|
if not (args.IsSpecified('reservation_affinity') or
|
|
args.IsSpecified('reservation')):
|
|
return None
|
|
|
|
def GetReservationAffinityEnum():
|
|
type_enum = None
|
|
if args.IsSpecified('reservation_affinity'):
|
|
reservation_affinity_message = messages.ReservationAffinity
|
|
type_enum = arg_utils.ChoiceEnumMapper(
|
|
arg_name='reservation-affinity',
|
|
message_enum=(reservation_affinity_message
|
|
.ConsumeReservationTypeValueValuesEnum),
|
|
include_filter=lambda x: 'UNSPECIFIED' not in x).GetEnumForChoice(
|
|
arg_utils.EnumNameToChoice(args.reservation_affinity))
|
|
return type_enum
|
|
|
|
reservation_affinity_enum = GetReservationAffinityEnum()
|
|
reservation_key = None
|
|
reservation_values = []
|
|
|
|
if (reservation_affinity_enum == messages.ReservationAffinity
|
|
.ConsumeReservationTypeValueValuesEnum.SPECIFIC_RESERVATION):
|
|
# Currently, the key is fixed and the value is the name of the
|
|
# reservation.
|
|
# The value being a repeated field is reserved for future use when user
|
|
# can specify more than one reservation names from which the VM can take
|
|
# capacity from.
|
|
reservation_key = _RESERVATION_AFFINITY_KEY
|
|
reservation_values = [args.reservation]
|
|
|
|
reservation_config_message = messages.ReservationAffinity
|
|
return reservation_config_message(
|
|
consumeReservationType=reservation_affinity_enum,
|
|
key=reservation_key,
|
|
values=reservation_values,
|
|
)
|
|
|
|
instance = messages.Instance(
|
|
name=args.instance,
|
|
postStartupScript=args.post_startup_script,
|
|
customGpuDriverPath=args.custom_gpu_driver_path,
|
|
instanceOwners=GetInstanceOwnersFromArgs(),
|
|
kmsKey=GetKmsRelativeName(),
|
|
machineType=args.machine_type,
|
|
network=GetNetworkRelativeName(),
|
|
noProxyAccess=args.no_proxy_access,
|
|
noPublicIp=args.no_public_ip,
|
|
serviceAccount=args.service_account,
|
|
subnet=GetSubnetRelativeName(),
|
|
vmImage=CreateVmImageFromArgs(),
|
|
acceleratorConfig=CreateAcceleratorConfigMessage(),
|
|
bootDiskType=GetBootDisk(),
|
|
bootDiskSizeGb=args.boot_disk_size,
|
|
dataDiskType=GetDataDisk(),
|
|
dataDiskSizeGb=args.data_disk_size,
|
|
noRemoveDataDisk=args.no_remove_data_disk,
|
|
containerImage=CreateContainerImageFromArgs(),
|
|
diskEncryption=GetDiskEncryption(),
|
|
labels=GetLabelsFromArgs(),
|
|
metadata=GetMetadataFromArgs(),
|
|
installGpuDriver=args.install_gpu_driver,
|
|
shieldedInstanceConfig=GetShieldedInstanceConfigFromArgs(),
|
|
reservationAffinity=GetReservationAffinityConfigFromArgs(),
|
|
tags=GetTagsFromArgs(),
|
|
)
|
|
return instance
|
|
|
|
|
|
def CreateInstanceCreateRequest(args, client, messages):
|
|
parent = util.GetParentForInstance(args)
|
|
instance = CreateInstance(args, client, messages)
|
|
return messages.NotebooksProjectsLocationsInstancesCreateRequest(
|
|
parent=parent, instance=instance, instanceId=args.instance)
|
|
|
|
|
|
def CreateInstanceListRequest(args, messages):
|
|
parent = util.GetParentFromArgs(args)
|
|
return messages.NotebooksProjectsLocationsInstancesListRequest(parent=parent)
|
|
|
|
|
|
def CreateInstanceDeleteRequest(args, messages):
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
return messages.NotebooksProjectsLocationsInstancesDeleteRequest(
|
|
name=instance)
|
|
|
|
|
|
def CreateInstanceDescribeRequest(args, messages):
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
return messages.NotebooksProjectsLocationsInstancesGetRequest(name=instance)
|
|
|
|
|
|
def CreateInstanceRegisterRequest(args, messages):
|
|
instance = GetInstanceResource(args)
|
|
parent = util.GetLocationResource(instance.locationsId,
|
|
instance.projectsId).RelativeName()
|
|
register_request = messages.RegisterInstanceRequest(
|
|
instanceId=instance.Name())
|
|
return messages.NotebooksProjectsLocationsInstancesRegisterRequest(
|
|
parent=parent, registerInstanceRequest=register_request)
|
|
|
|
|
|
def CreateInstanceResetRequest(args, messages):
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
reset_request = messages.ResetInstanceRequest()
|
|
return messages.NotebooksProjectsLocationsInstancesResetRequest(
|
|
name=instance, resetInstanceRequest=reset_request)
|
|
|
|
|
|
def CreateInstanceStartRequest(args, messages):
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
start_request = messages.StartInstanceRequest()
|
|
return messages.NotebooksProjectsLocationsInstancesStartRequest(
|
|
name=instance, startInstanceRequest=start_request)
|
|
|
|
|
|
def CreateInstanceStopRequest(args, messages):
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
stop_request = messages.StopInstanceRequest()
|
|
return messages.NotebooksProjectsLocationsInstancesStopRequest(
|
|
name=instance, stopInstanceRequest=stop_request)
|
|
|
|
|
|
def CreateSetAcceleratorRequest(args, messages):
|
|
"""Create and return Accelerator update request."""
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
set_acc_request = messages.SetInstanceAcceleratorRequest()
|
|
accelerator_config = messages.SetInstanceAcceleratorRequest
|
|
if args.IsSpecified('accelerator_core_count'):
|
|
set_acc_request.coreCount = args.accelerator_core_count
|
|
if args.IsSpecified('accelerator_type'):
|
|
type_enum = arg_utils.ChoiceEnumMapper(
|
|
arg_name='accelerator-type',
|
|
message_enum=accelerator_config.TypeValueValuesEnum,
|
|
include_filter=lambda x: 'UNSPECIFIED' not in x).GetEnumForChoice(
|
|
arg_utils.EnumNameToChoice(args.accelerator_type))
|
|
set_acc_request.type = type_enum
|
|
return messages.NotebooksProjectsLocationsInstancesSetAcceleratorRequest(
|
|
name=instance, setInstanceAcceleratorRequest=set_acc_request)
|
|
|
|
|
|
def CreateSetLabelsRequest(args, messages):
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
set_label_request = messages.SetInstanceLabelsRequest()
|
|
labels_message = messages.SetInstanceLabelsRequest.LabelsValue
|
|
set_label_request.labels = labels_message(additionalProperties=[
|
|
labels_message.AdditionalProperty(key=key, value=value)
|
|
for key, value in args.labels.items()
|
|
])
|
|
return messages.NotebooksProjectsLocationsInstancesSetLabelsRequest(
|
|
name=instance, setInstanceLabelsRequest=set_label_request)
|
|
|
|
|
|
def CreateSetMachineTypeRequest(args, messages):
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
set_machine_request = messages.SetInstanceMachineTypeRequest(
|
|
machineType=args.machine_type)
|
|
return messages.NotebooksProjectsLocationsInstancesSetMachineTypeRequest(
|
|
name=instance, setInstanceMachineTypeRequest=set_machine_request)
|
|
|
|
|
|
def CreateInstanceGetHealthRequest(args, messages):
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
return messages.NotebooksProjectsLocationsInstancesGetInstanceHealthRequest(
|
|
name=instance)
|
|
|
|
|
|
def CreateInstanceIsUpgradeableRequest(args, messages):
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
return messages.NotebooksProjectsLocationsInstancesIsUpgradeableRequest(
|
|
notebookInstance=instance)
|
|
|
|
|
|
def CreateInstanceUpgradeRequest(args, messages):
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
upgrade_request = messages.UpgradeInstanceRequest()
|
|
return messages.NotebooksProjectsLocationsInstancesUpgradeRequest(
|
|
name=instance, upgradeInstanceRequest=upgrade_request)
|
|
|
|
|
|
def CreateInstanceRollbackRequest(args, messages):
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
rollback_request = messages.RollbackInstanceRequest(
|
|
targetSnapshot=args.target_snapshot)
|
|
return messages.NotebooksProjectsLocationsInstancesRollbackRequest(
|
|
name=instance, rollbackInstanceRequest=rollback_request)
|
|
|
|
|
|
def CreateInstanceDiagnoseRequest(args, messages):
|
|
""""Create and return Diagnose request."""
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
diagnostic_config = messages.DiagnosticConfig(
|
|
gcsBucket=args.gcs_bucket,
|
|
)
|
|
if args.IsSpecified('relative_path'):
|
|
diagnostic_config.relativePath = args.relative_path
|
|
if args.IsSpecified('enable-repair'):
|
|
diagnostic_config.repairFlagEnabled = True
|
|
if args.IsSpecified('enable-packet-capture'):
|
|
diagnostic_config.packetCaptureFlagEnabled = True
|
|
if args.IsSpecified('enable-copy-home-files'):
|
|
diagnostic_config.copyHomeFilesFlagEnabled = True
|
|
|
|
timeout_minutes = None
|
|
if args.IsSpecified('timeout_minutes'):
|
|
timeout_minutes = int(args.timeout_minutes)
|
|
|
|
return messages.NotebooksProjectsLocationsInstancesDiagnoseRequest(
|
|
name=instance, diagnoseInstanceRequest=messages.DiagnoseInstanceRequest(
|
|
diagnosticConfig=diagnostic_config, timeoutMinutes=timeout_minutes))
|
|
|
|
|
|
def CreateInstanceMigrateRequest(args, messages):
|
|
""""Create and return Migrate request."""
|
|
instance = GetInstanceResource(args).RelativeName()
|
|
|
|
def GetPostStartupScriptOption():
|
|
type_enum = None
|
|
if args.IsSpecified('post_startup_script_option'):
|
|
request_message = messages.MigrateInstanceRequest
|
|
type_enum = arg_utils.ChoiceEnumMapper(
|
|
arg_name='post-startup-script-option',
|
|
message_enum=request_message.PostStartupScriptOptionValueValuesEnum,
|
|
include_filter=lambda x: 'UNSPECIFIED' not in x).GetEnumForChoice(
|
|
arg_utils.EnumNameToChoice(args.post_startup_script_option))
|
|
return type_enum
|
|
|
|
return messages.NotebooksProjectsLocationsInstancesMigrateRequest(
|
|
name=instance,
|
|
migrateInstanceRequest=messages.MigrateInstanceRequest(
|
|
postStartupScriptOption=GetPostStartupScriptOption(),
|
|
))
|
|
|
|
|
|
def GetInstanceResource(args):
|
|
return args.CONCEPTS.instance.Parse()
|
|
|
|
|
|
def GetInstanceURI(resource):
|
|
instance = resources.REGISTRY.ParseRelativeName(
|
|
resource.name, collection='notebooks.projects.locations.instances')
|
|
return instance.SelfLink()
|
|
|
|
|
|
class OperationType(enum.Enum):
|
|
CREATE = (log.CreatedResource, 'created')
|
|
UPDATE = (log.UpdatedResource, 'updated')
|
|
UPGRADE = (log.UpdatedResource, 'upgraded')
|
|
ROLLBACK = (log.UpdatedResource, 'rolled back')
|
|
DELETE = (log.DeletedResource, 'deleted')
|
|
RESET = (log.ResetResource, 'reset')
|
|
MIGRATE = (log.UpdatedResource, 'migrated')
|
|
|
|
|
|
def HandleLRO(operation,
|
|
args,
|
|
instance_service,
|
|
release_track,
|
|
operation_type=OperationType.UPDATE):
|
|
"""Handles Long Running Operations for both cases of async.
|
|
|
|
Args:
|
|
operation: The operation to poll.
|
|
args: ArgParse instance containing user entered arguments.
|
|
instance_service: The service to get the resource after the long running
|
|
operation completes.
|
|
release_track: base.ReleaseTrack object.
|
|
operation_type: Enum value of type OperationType indicating the kind of
|
|
operation to wait for.
|
|
|
|
Raises:
|
|
apitools.base.py.HttpError: if the request returns an HTTP error
|
|
|
|
Returns:
|
|
The Instance resource if synchronous, else the Operation Resource.
|
|
"""
|
|
logging_method = operation_type.value[0]
|
|
if args.async_:
|
|
logging_method(
|
|
util.GetOperationResource(operation.name, release_track),
|
|
kind='notebooks instance {0}'.format(args.instance),
|
|
is_async=True)
|
|
return operation
|
|
else:
|
|
response = util.WaitForOperation(
|
|
operation,
|
|
'Waiting for operation on Instance [{}] to be {} with [{}]'.format(
|
|
args.instance, operation_type.value[1], operation.name),
|
|
service=instance_service,
|
|
release_track=release_track,
|
|
is_delete=(operation_type.value[1] == 'deleted'))
|
|
logging_method(
|
|
util.GetOperationResource(operation.name, release_track),
|
|
kind='notebooks instance {0}'.format(args.instance),
|
|
is_async=False)
|
|
return response
|