feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,246 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants used for AI Platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
# Version labels used throughout the AI Platform surface to select between
# the GA and Beta API surfaces.
BETA_VERSION = 'BETA'
GA_VERSION = 'GA'

# Maps a version label to the aiplatform API version to invoke.
AI_PLATFORM_API_VERSION = {
    GA_VERSION: 'v1',
    BETA_VERSION: 'v1beta1'
}

# Maps a version label to the prefix of the generated API message class names.
AI_PLATFORM_MESSAGE_PREFIX = {
    GA_VERSION: 'GoogleCloudAiplatformV1',
    BETA_VERSION: 'GoogleCloudAiplatformV1beta1'
}

AI_PLATFORM_API_NAME = 'aiplatform'

# Maps a gcloud release track to the version label defined above.
AI_PLATFORM_RELEASE_TRACK_TO_VERSION = {
    base.ReleaseTrack.GA: GA_VERSION,
    base.ReleaseTrack.BETA: BETA_VERSION
}

# TODO(b/448146624): To remove hardcoded regions in non-default universes,
# by calling location service to get available regions in non-default universes.
NON_DEFAULT_UNIVERSE_REGIONS = ('u-us-prp1',)

# NOTE: The following region lists are GDU-specific.
# They should ONLY be used as fallbacks within the Google Default Universe
# when dynamic region lookups fail. Do not use for validation or prompting
# in non GDU environments.

# The default available regions for most Vertex AI products. See
# https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability
# for more details.
SUPPORTED_REGION = ('us-central1', 'europe-west4', 'asia-east1')

# NOTE: GDU-specific, see comment above SUPPORTED_REGION.
# Available regions specifically for training, including custom-jobs and
# hp-tuning-jobs.
SUPPORTED_TRAINING_REGIONS = (
    'africa-south1',
    'asia-east1',
    'asia-east2',
    'asia-northeast1',
    'asia-northeast2',
    'asia-northeast3',
    'asia-south1',
    'asia-south2',
    'asia-southeast1',
    'asia-southeast2',
    'australia-southeast1',
    'australia-southeast2',
    'europe-central2',
    'europe-north1',
    'europe-southwest1',
    'europe-west1',
    'europe-west2',
    'europe-west3',
    'europe-west4',
    'europe-west6',
    'europe-west8',
    'europe-west9',
    'europe-west12',
    'me-central1',
    'me-central2',
    'me-west1',
    'northamerica-northeast1',
    'northamerica-northeast2',
    'southamerica-east1',
    'southamerica-west1',
    'us-central1',
    'us-east1',
    'us-east4',
    'us-east5',
    'us-south1',
    'us-west1',
    'us-west2',
    'us-west3',
    'us-west4',
)

# NOTE: GDU-specific, see comment above SUPPORTED_REGION.
# Available regions specifically for online prediction, including endpoints and
# models
SUPPORTED_OP_REGIONS = (
    'africa-south1',
    'asia-east1',
    'asia-east2',
    'asia-northeast1',
    'asia-northeast2',
    'asia-northeast3',
    'asia-south1',
    'asia-south2',
    'asia-southeast1',
    'asia-southeast2',
    'australia-southeast1',
    'australia-southeast2',
    'europe-central2',
    'europe-north1',
    'europe-southwest1',
    'europe-west1',
    'europe-west2',
    'europe-west3',
    'europe-west4',
    'europe-west6',
    'europe-west8',
    'europe-west9',
    'europe-west12',
    'me-central1',
    'me-central2',
    'me-west1',
    'northamerica-northeast1',
    'northamerica-northeast2',
    'southamerica-east1',
    'southamerica-west1',
    'us-central1',
    'us-east1',
    'us-east4',
    'us-east5',
    'us-south1',
    'us-west1',
    'us-west2',
    'us-west3',
    'us-west4',
)

# NOTE: GDU-specific, see comment above SUPPORTED_REGION.
# Available regions specifically for deployment resource pools
SUPPORTED_DEPLOYMENT_RESOURCE_POOL_REGIONS = (
    'us-central1',
    'us-east1',
    'us-east4',
    'us-west1',
    'europe-west1',
    'asia-northeast1',
    'asia-southeast1',
)

# NOTE: GDU-specific, see comment above SUPPORTED_REGION.
# Available regions specifically for model monitoring jobs
SUPPORTED_MODEL_MONITORING_JOBS_REGIONS = (
    'asia-east1',
    'asia-east2',
    'asia-northeast1',
    'asia-northeast3',
    'asia-south1',
    'asia-southeast1',
    'asia-southeast2',
    'australia-southeast1',
    'europe-central2',
    'europe-west1',
    'europe-west2',
    'europe-west3',
    'europe-west4',
    'europe-west6',
    'europe-west9',
    'northamerica-northeast1',
    'northamerica-northeast2',
    'southamerica-east1',
    'us-central1',
    'us-east1',
    'us-east4',
    'us-west1',
    'us-west2',
    'us-west3',
    'us-west4',
)

# Template shown to the user after a long-running operation is submitted.
# Placeholders: verb, name, id, sub_commands.
OPERATION_CREATION_DISPLAY_MESSAGE = """\
The {verb} operation [{name}] was submitted successfully.
You may view the status of your operation with the command
$ gcloud ai operations describe {id} {sub_commands}\
"""

# Resource collection paths used to build aiplatform resource references.
DEFAULT_OPERATION_COLLECTION = 'aiplatform.projects.locations.operations'
DEPLOYMENT_RESOURCE_POOLS_COLLECTION = 'aiplatform.projects.locations.deploymentResourcePools'
ENDPOINTS_COLLECTION = 'aiplatform.projects.locations.endpoints'
INDEX_ENDPOINTS_COLLECTION = 'aiplatform.projects.locations.indexEndpoints'
INDEXES_COLLECTION = 'aiplatform.projects.locations.indexes'
TENSORBOARDS_COLLECTION = 'aiplatform.projects.locations.tensorboards'
TENSORBOARD_EXPERIMENTS_COLLECTION = 'aiplatform.projects.locations.tensorboards.experiments'
TENSORBOARD_RUNS_COLLECTION = 'aiplatform.projects.locations.tensorboards.experiments.runs'
TENSORBOARD_TIME_SERIES_COLLECTION = 'aiplatform.projects.locations.tensorboards.experiments.runs.timeSeries'
MODEL_MONITORING_JOBS_COLLECTION = 'aiplatform.projects.locations.modelDeploymentMonitoringJobs'

# Maps the user-facing autoscaling metric flag values to the Cloud Monitoring
# metric names used by online prediction autoscaling.
# gcloud-disable-gdu-domain
OP_AUTOSCALING_METRIC_NAME_MAPPER = {
    'cpu-usage': 'aiplatform.googleapis.com/prediction/online/cpu/utilization',
    'gpu-duty-cycle': (
        'aiplatform.googleapis.com/prediction/online/accelerator/duty_cycle'
    ),
    'request-counts-per-minute': (
        'aiplatform.googleapis.com/prediction/online/request_count'
    ),
}

# Templates shown after model monitoring job mutations. Placeholders:
# id, cmd_prefix, state.
MODEL_MONITORING_JOB_CREATION_DISPLAY_MESSAGE = """\
Model monitoring Job [{id}] submitted successfully.
Your job is still active. You may view the status of your job with the command
$ {cmd_prefix} ai model-monitoring-jobs describe {id}
Job State: {state}\
"""

MODEL_MONITORING_JOB_PAUSE_DISPLAY_MESSAGE = """\
Request to pause model deployment monitoring job [{id}] has been sent
You may view the status of your job with the command
$ {cmd_prefix} ai model-monitoring-jobs describe {id}
"""

MODEL_MONITORING_JOB_RESUME_DISPLAY_MESSAGE = """\
Request to resume model deployment monitoring job [{id}] has been sent
You may view the status of your job with the command
$ {cmd_prefix} ai model-monitoring-jobs describe {id}
"""

View File

@@ -0,0 +1,253 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform custom jobs commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.command_lib.ai.custom_jobs import local_util
from googlecloudsdk.command_lib.ai.docker import build as docker_build
from googlecloudsdk.command_lib.ai.docker import utils as docker_utils
from googlecloudsdk.command_lib.util.apis import arg_utils
from googlecloudsdk.core import log
from googlecloudsdk.core.util import files
# TODO(b/191347326): Consider adding tests for the "public" methods in this file
CUSTOM_JOB_COLLECTION = 'aiplatform.projects.locations.customJobs'
def _ConstructSingleWorkerPoolSpec(aiplatform_client,
spec,
python_package_uri=None):
"""Constructs the specification of a single worker pool.
Args:
aiplatform_client: The AI Platform API client used.
spec: A dict whose fields represent a worker pool config.
python_package_uri: str, The common python package uris that will be used by
executor image, supposedly derived from the gcloud command flags.
Returns:
A WorkerPoolSpec message instance for setting a worker pool in a custom job.
"""
worker_pool_spec = aiplatform_client.GetMessage('WorkerPoolSpec')()
machine_spec_msg = aiplatform_client.GetMessage('MachineSpec')
machine_spec = machine_spec_msg(machineType=spec.get('machine-type'))
accelerator_type = spec.get('accelerator-type')
if accelerator_type:
machine_spec.acceleratorType = arg_utils.ChoiceToEnum(
accelerator_type, machine_spec_msg.AcceleratorTypeValueValuesEnum)
machine_spec.acceleratorCount = int(spec.get('accelerator-count', 1))
worker_pool_spec.machineSpec = machine_spec
worker_pool_spec.replicaCount = int(spec.get('replica-count', 1))
container_image_uri = spec.get('container-image-uri')
executor_image_uri = spec.get('executor-image-uri')
python_module = spec.get('python-module')
if container_image_uri:
container_spec_msg = aiplatform_client.GetMessage('ContainerSpec')
worker_pool_spec.containerSpec = container_spec_msg(
imageUri=container_image_uri)
elif python_package_uri or executor_image_uri or python_module:
python_package_spec_msg = aiplatform_client.GetMessage('PythonPackageSpec')
worker_pool_spec.pythonPackageSpec = python_package_spec_msg(
executorImageUri=executor_image_uri,
packageUris=(python_package_uri or []),
pythonModule=python_module)
return worker_pool_spec
def _ConstructWorkerPoolSpecs(aiplatform_client, specs, **kwargs):
"""Constructs the specification of the worker pools in a CustomJobSpec instance.
Args:
aiplatform_client: The AI Platform API client used.
specs: A list of dict of worker pool specifications, supposedly derived from
the gcloud command flags.
**kwargs: The keyword args to pass down to construct each worker pool spec.
Returns:
A list of WorkerPoolSpec message instances for creating a custom job.
"""
worker_pool_specs = []
for spec in specs:
if spec:
worker_pool_specs.append(
_ConstructSingleWorkerPoolSpec(aiplatform_client, spec, **kwargs))
else:
worker_pool_specs.append(aiplatform_client.GetMessage('WorkerPoolSpec')())
return worker_pool_specs
def _PrepareTrainingImage(project,
                          job_name,
                          base_image,
                          local_package,
                          script,
                          output_image_name,
                          python_module=None,
                          **kwargs):
  """Builds a training image from a local package and pushes it to the Cloud.

  Args:
    project: str, the project id the image is built for.
    job_name: str, display name of the custom job, used to derive a default
      image name when `output_image_name` is not given.
    base_image: str, URI of the container image to build upon.
    local_package: str, local path of the folder containing training code;
      a leading `~` is expanded.
    script: str, relative path of the main script inside the local package.
    output_image_name: str or None, explicit URI for the built image.
    python_module: str or None, python module to execute instead of a script.
    **kwargs: extra keyword args forwarded to the docker build.

  Returns:
    str, the URI of the container image built for the custom job.
  """
  image_uri = output_image_name or docker_utils.GenerateImageName(
      base_name=job_name, project=project, is_gcr=True)

  docker_build.BuildImage(
      base_image=base_image,
      host_workdir=files.ExpandHomeDir(local_package),
      main_script=script,
      python_module=python_module,
      output_image_name=image_uri,
      **kwargs)
  log.status.Print('\nA custom container image is built locally.\n')

  # Push the freshly built image so the training service can pull it.
  docker_utils.ExecuteDockerCommand(['docker', 'push', image_uri])
  log.status.Print(
      '\nCustom container image [{}] is created for your custom job.\n'.format(
          image_uri))
  return image_uri
def UpdateWorkerPoolSpecsIfLocalPackageRequired(worker_pool_specs, job_name,
                                                project):
  """Updates the given worker pool specifications if any contains local packages.

  If the first worker pool spec specifies a local package, this builds a
  Docker image from the local package and updates every non-empty spec to
  use it.

  Args:
    worker_pool_specs: list of dict representing the arg value specified via
      the `--worker-pool-spec` flag.
    job_name: str, the display name of the custom job corresponding to the
      worker pool specs.
    project: str, id of the project to which the custom job is submitted.

  Yields:
    All updated worker pool specifications that uses the already built
    packages and are expectedly passed to a custom-jobs create RPC request.
  """
  built_image = None
  if worker_pool_specs and 'local-package-path' in worker_pool_specs[0]:
    # Auto-packaging path: consume (pop) the build-only keys from the first
    # spec so they are not sent to the API, then build & push the image once.
    first = worker_pool_specs[0]
    base_image = first.pop('executor-image-uri')
    local_package = first.pop('local-package-path')
    python_module = first.pop('python-module', None)
    if python_module:
      entry_point = local_util.ModuleToPath(python_module)
    else:
      entry_point = first.pop('script')
    built_image = _PrepareTrainingImage(
        project=project,
        job_name=job_name,
        base_image=base_image,
        local_package=local_package,
        script=entry_point,
        output_image_name=first.pop('output-image-uri', None),
        python_module=python_module,
        requirements=first.pop('requirements', None),
        extra_packages=first.pop('extra-packages', None),
        extra_dirs=first.pop('extra-dirs', None))

  for spec in worker_pool_specs:
    if built_image and spec:
      updated = dict(spec)
      updated['container-image-uri'] = built_image
      yield updated
    else:
      yield spec
def ConstructCustomJobSpec(
    aiplatform_client,
    base_config=None,
    network=None,
    service_account=None,
    enable_web_access=None,
    enable_dashboard_access=None,
    worker_pool_specs=None,
    args=None,
    command=None,
    persistent_resource_id=None,
    **kwargs):
  """Constructs the spec of a custom job to be used in job creation request.

  Args:
    aiplatform_client: The AI Platform API client used.
    base_config: A base CustomJobSpec message instance, e.g. imported from a
      YAML config file, as a template to be overridden.
    network: user network to which the job should be peered with (overrides
      yaml file)
    service_account: A service account (email address string) to use for the
      job.
    enable_web_access: Whether to enable the interactive shell for the job.
    enable_dashboard_access: Whether to enable the access to the dashboard
      built on the job.
    worker_pool_specs: A dict of worker pool specification, usually derived
      from the gcloud command argument values.
    args: A list of arguments to be passed to containers or python packge,
      supposedly derived from the gcloud command flags.
    command: A list of commands to be passed to containers, supposedly
      derived from the gcloud command flags.
    persistent_resource_id: The name of the persistent resource from the same
      project and region on which to run this custom job.
    **kwargs: The keyword args to pass to construct the worker pool specs.

  Returns:
    A CustomJobSpec message instance for creating a custom job.
  """
  job_spec = base_config

  # Scalar overrides: only applied when the caller supplied a value, so the
  # base config (e.g. from YAML) is otherwise preserved.
  if network is not None:
    job_spec.network = network
  if service_account is not None:
    job_spec.serviceAccount = service_account
  if enable_web_access:
    job_spec.enableWebAccess = enable_web_access
  if enable_dashboard_access:
    job_spec.enableDashboardAccess = enable_dashboard_access

  if worker_pool_specs:
    job_spec.workerPoolSpecs = _ConstructWorkerPoolSpecs(
        aiplatform_client, worker_pool_specs, **kwargs)

  # Propagate runtime args to both container and python-package pools, and
  # command overrides to container pools only.
  if args:
    for pool in job_spec.workerPoolSpecs:
      if pool.containerSpec:
        pool.containerSpec.args = args
      if pool.pythonPackageSpec:
        pool.pythonPackageSpec.args = args
  if command:
    for pool in job_spec.workerPoolSpecs:
      if pool.containerSpec:
        pool.containerSpec.command = command

  if persistent_resource_id:
    job_spec.persistentResourceId = persistent_resource_id
  return job_spec
def _IsKwargsDefined(key, **kwargs):
return key in kwargs and bool(kwargs.get(key))

View File

@@ -0,0 +1,410 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flags definition specifically for gcloud ai custom-jobs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import argparse
import textwrap
from googlecloudsdk.calliope import arg_parsers
from googlecloudsdk.calliope import base
from googlecloudsdk.calliope.concepts import concepts
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import flags as shared_flags
from googlecloudsdk.command_lib.ai import region_util
from googlecloudsdk.command_lib.ai.custom_jobs import custom_jobs_util
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.command_lib.util.concepts import concept_parsers
# Required human-readable name for the custom job being created.
_DISPLAY_NAME = base.Argument(
    '--display-name',
    required=True,
    help=('Display name of the custom job to create.'))
# Common Python package URIs shared by all worker pools; per-pool packages
# require `--config`.
_PYTHON_PACKAGE_URIS = base.Argument(
    '--python-package-uris',
    metavar='PYTHON_PACKAGE_URIS',
    type=arg_parsers.ArgList(),
    help=('The common Python package URIs to be used for training with a '
          'pre-built container image. e.g. `--python-package-uris=path1,path2` '
          'If you are using multiple worker pools and want to specify a '
          'different Python package for each pool, use `--config` instead.'))
_CUSTOM_JOB_CONFIG = base.Argument(
'--config',
help=textwrap.dedent("""\
Path to the job configuration file. This file should be a YAML document
containing a [`CustomJobSpec`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec).
If an option is specified both in the configuration file **and** via command-line arguments, the command-line arguments
override the configuration file. Note that keys with underscore are invalid.
Example(YAML):
workerPoolSpecs:
machineSpec:
machineType: n1-highmem-2
replicaCount: 1
containerSpec:
imageUri: gcr.io/ucaip-test/ucaip-training-test
args:
- port=8500
command:
- start"""))
# Repeatable flag describing one worker pool of the custom job. Each key in
# the ArgDict maps onto a field of the `WorkerPoolSpec` API message or drives
# local auto-packaging.
_WORKER_POOL_SPEC = base.Argument(
    '--worker-pool-spec',
    action='append',
    type=arg_parsers.ArgDict(
        spec={
            'replica-count': int,
            'machine-type': str,
            'accelerator-type': str,
            'accelerator-count': int,
            'container-image-uri': str,
            'executor-image-uri': str,
            'output-image-uri': str,
            'python-module': str,
            'script': str,
            'local-package-path': str,
            'requirements': arg_parsers.ArgList(custom_delim_char=';'),
            'extra-dirs': arg_parsers.ArgList(custom_delim_char=';'),
            'extra-packages': arg_parsers.ArgList(custom_delim_char=';'),
        }),
    metavar='WORKER_POOL_SPEC',
    help=textwrap.dedent("""\
      Define the worker pool configuration used by the custom job. You can
      specify multiple worker pool specs in order to create a custom job with
      multiple worker pools.

      The spec can contain the following fields:

      *machine-type*::: (Required): The type of the machine.
      see https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types
      for supported types. This is corresponding to the `machineSpec.machineType`
      field in `WorkerPoolSpec` API message.
      *replica-count*::: The number of worker replicas to use for this worker
      pool, by default the value is 1. This is corresponding to the `replicaCount`
      field in `WorkerPoolSpec` API message.
      *accelerator-type*::: The type of GPUs.
      see https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus
      for more requirements. This is corresponding to the `machineSpec.acceleratorType`
      field in `WorkerPoolSpec` API message.
      *accelerator-count*::: The number of GPUs for each VM in the worker pool to
      use, by default the value is 1. This is corresponding to the
      `machineSpec.acceleratorCount` field in `WorkerPoolSpec` API message.
      *container-image-uri*::: The URI of a container image to be directly run on
      each worker replica. This is corresponding to the
      `containerSpec.imageUri` field in `WorkerPoolSpec` API message.
      *executor-image-uri*::: The URI of a container image that will run the
      provided package.
      *output-image-uri*::: The URI of a custom container image to be built for
      autopackaged custom jobs.
      *python-module*::: The Python module name to run within the provided
      package.
      *local-package-path*::: The local path of a folder that contains training
      code.
      *script*::: The relative path under the `local-package-path` to a file to
      execute. It can be a Python file or an arbitrary bash script.
      *requirements*::: Python dependencies to be installed from PyPI, separated
      by ";". This is supposed to be used when some public packages are
      required by your training application but not in the base images.
      It has the same effect as editing a "requirements.txt" file under
      `local-package-path`.
      *extra-packages*::: Relative paths of local Python archives to be installed,
      separated by ";". This is supposed to be used when some custom packages
      are required by your training application but not in the base images.
      Every path should be relative to the `local-package-path`.
      *extra-dirs*::: Relative paths of the folders under `local-package-path`
      to be copied into the container, separated by ";". If not specified, only
      the parent directory that contains the main executable (`script` or
      `python-module`) will be copied.
      ::::

      Note that some of these fields are used for different job creation methods
      and are categorized as mutually exclusive groups listed below. Exactly one of
      these groups of fields must be specified:

      `container-image-uri`::::
      Specify this field to use a custom container image for training. Together
      with the `--command` and `--args` flags, this field represents a
      [`WorkerPoolSpec.ContainerSpec`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec?#containerspec)
      message.
      In this case, the `--python-package-uris` flag is disallowed.

      Example:
      --worker-pool-spec=replica-count=1,machine-type=n1-highmem-2,container-image-uri=gcr.io/ucaip-test/ucaip-training-test

      `executor-image-uri, python-module`::::
      Specify these fields to train using a pre-built container and Python
      packages that are already in Cloud Storage. Together with the
      `--python-package-uris` and `--args` flags, these fields represent a
      [`WorkerPoolSpec.PythonPackageSpec`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#pythonpackagespec)
      message .

      Example:
      --worker-pool-spec=machine-type=e2-standard-4,executor-image-uri=us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-4:latest,python-module=trainer.task

      `output-image-uri`::::
      Specify this field to push the output custom container training image to a specific path in Container Registry or Artifact Registry for an autopackaged custom job.

      Example:
      --worker-pool-spec=machine-type=e2-standard-4,executor-image-uri=us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-4:latest,output-image-uri='eu.gcr.io/projectName/imageName',python-module=trainer.task

      `local-package-path, executor-image-uri, output-image-uri, python-module|script`::::
      Specify these fields, optionally with `requirements`, `extra-packages`, or
      `extra-dirs`, to train using a pre-built container and Python code from a
      local path.
      In this case, the `--python-package-uris` flag is disallowed.

      Example using `python-module`:
      --worker-pool-spec=machine-type=e2-standard-4,replica-count=1,executor-image-uri=us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-4:latest,python-module=trainer.task,local-package-path=/usr/page/application

      Example using `script`:
      --worker-pool-spec=machine-type=e2-standard-4,replica-count=1,executor-image-uri=us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-4:latest,script=my_run.sh,local-package-path=/usr/jeff/application
      """))
# Container entrypoint override; replaces the Dockerfile ENTRYPOINT.
_CUSTOM_JOB_COMMAND = base.Argument(
    '--command',
    type=arg_parsers.ArgList(),
    metavar='COMMAND',
    action=arg_parsers.UpdateAction,
    help="""\
Command to be invoked when containers are started.
It overrides the entrypoint instruction in Dockerfile when provided.
""")

# Runtime arguments forwarded to the container or python package.
_CUSTOM_JOB_ARGS = base.Argument(
    '--args',
    metavar='ARG',
    type=arg_parsers.ArgList(),
    action=arg_parsers.UpdateAction,
    help='Comma-separated arguments passed to containers or python tasks.')
# Optional ID of an existing PersistentResource to run the job on, instead of
# provisioning on-demand machines.
_PERSISTENT_RESOURCE_ID = base.Argument(
    '--persistent-resource-id',
    metavar='PERSISTENT_RESOURCE_ID',
    help="""\
The name of the persistent resource from the same project and region on
which to run this custom job.
If this is specified, the job will be run on existing machines held by the
PersistentResource instead of on-demand short-lived machines.
The network and CMEK configs on the job should be consistent with those on
the PersistentResource, otherwise, the job will be rejected.
""")
def AddCreateCustomJobFlags(parser):
  """Adds flags related to create a custom job."""
  # Region resource arg prompts from the GDU training-region fallback list.
  shared_flags.AddRegionResourceArg(
      parser,
      'to create a custom job',
      prompt_func=region_util.GetPromptForRegionFunc(
          constants.SUPPORTED_TRAINING_REGIONS))
  for shared_flag in (shared_flags.TRAINING_SERVICE_ACCOUNT,
                      shared_flags.NETWORK,
                      shared_flags.ENABLE_WEB_ACCESS,
                      shared_flags.ENABLE_DASHBOARD_ACCESS):
    shared_flag.AddToParser(parser)
  shared_flags.AddKmsKeyResourceArg(parser, 'custom job')
  labels_util.AddCreateLabelsFlags(parser)
  for job_flag in (_DISPLAY_NAME, _PYTHON_PACKAGE_URIS, _CUSTOM_JOB_ARGS,
                   _CUSTOM_JOB_COMMAND, _PERSISTENT_RESOURCE_ID):
    job_flag.AddToParser(parser)
  # The worker pool configuration must come from either --config or
  # --worker-pool-spec (at least one is required).
  spec_group = base.ArgumentGroup(
      help='Worker pool specification.', required=True)
  spec_group.AddArgument(_CUSTOM_JOB_CONFIG)
  spec_group.AddArgument(_WORKER_POOL_SPEC)
  spec_group.AddToParser(parser)
def AddCustomJobResourceArg(parser,
                            verb,
                            regions=constants.SUPPORTED_TRAINING_REGIONS):
  """Adds a resource argument for a Vertex AI custom job.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the job resource, such as 'to update'.
    regions: list[str], the list of supported regions.
  """
  region_config = shared_flags.RegionAttributeConfig(
      prompt_func=region_util.GetPromptForRegionFunc(regions))
  resource_spec = concepts.ResourceSpec(
      resource_collection=custom_jobs_util.CUSTOM_JOB_COLLECTION,
      resource_name='custom job',
      locationsId=region_config,
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      disable_auto_completers=False)
  presentation = concept_parsers.ConceptParser.ForResource(
      'custom_job',
      resource_spec,
      'The custom job {}.'.format(verb),
      required=True)
  presentation.AddToParser(parser)
def AddLocalRunCustomJobFlags(parser):
  """Add local-run related flags to the parser.

  Args:
    parser: the argparse-style parser for the `local-run` command.
  """
  # Flags for entry point of the training application
  application_group = parser.add_mutually_exclusive_group()
  application_group.add_argument(
      '--python-module',
      metavar='PYTHON_MODULE',
      help=textwrap.dedent("""
      Name of the python module to execute, in 'trainer.train' or 'train'
      format. Its path should be relative to the `work_dir`.
      """))
  application_group.add_argument(
      '--script',
      metavar='SCRIPT',
      help=textwrap.dedent("""
      The relative path of the file to execute. Accepts a Python file or an
      arbitrary bash script. This path should be relative to the `work_dir`.
      """))

  # Flags for working directory.
  parser.add_argument(
      '--local-package-path',
      metavar='LOCAL_PATH',
      suggestion_aliases=['--work-dir'],
      help=textwrap.dedent("""
      local path of the directory where the python-module or script exists.
      If not specified, it uses the directory where you run this command.

      Only the contents of this directory will be accessible to the built
      container image.
      """))

  # Flags for extra directory
  parser.add_argument(
      '--extra-dirs',
      metavar='EXTRA_DIR',
      type=arg_parsers.ArgList(),
      help=textwrap.dedent("""
      Extra directories under the working directory to include, besides the one
      that contains the main executable.

      By default, only the parent directory of the main script or python module
      is copied to the container.
      For example, if the module is "training.task" or the script is
      "training/task.py", the whole "training" directory, including its
      sub-directories, will always be copied to the container. You may specify
      this flag to also copy other directories if necessary.

      Note: if no parent is specified in 'python_module' or 'script', the whole
      working directory is copied, then you don't need to specify this flag.
      """))

  # Flags for base container image
  parser.add_argument(
      '--executor-image-uri',
      metavar='IMAGE_URI',
      required=True,
      suggestion_aliases=['--base-image'],
      help=textwrap.dedent("""
      URI or ID of the container image in either the Container Registry or local
      that will run the application.
      See https://cloud.google.com/vertex-ai/docs/training/pre-built-containers
      for available pre-built container images provided by Vertex AI for training.
      """))

  # Flags for extra requirements.
  parser.add_argument(
      '--requirements',
      metavar='REQUIREMENTS',
      type=arg_parsers.ArgList(),
      help=textwrap.dedent("""
      Python dependencies from PyPI to be used when running the application.
      If this is not specified, and there is no "setup.py" or "requirements.txt"
      in the working directory, your application will only have access to what
      exists in the base image with no other dependencies.

      Example:
      'tensorflow-cpu, pandas==1.2.0, matplotlib>=3.0.2'
      """))

  # Flags for extra dependencies.
  parser.add_argument(
      '--extra-packages',
      metavar='PACKAGE',
      type=arg_parsers.ArgList(),
      help=textwrap.dedent("""
      Local paths to Python archives used as training dependencies in the image
      container.
      These can be absolute or relative paths. However, they have to be under
      the work_dir; Otherwise, this tool will not be able to access it.

      Example:
      'dep1.tar.gz, ./downloads/dep2.whl'
      """))

  # Flags for the output image
  parser.add_argument(
      '--output-image-uri',
      metavar='OUTPUT_IMAGE',
      help=textwrap.dedent("""
      Uri of the custom container image to be built with your application
      packed in.
      """))

  # Flags for GPU support
  parser.add_argument(
      '--gpu', action='store_true', default=False, help='Enable to use GPU.')

  # Flags for docker run
  parser.add_argument(
      '--docker-run-options',
      metavar='DOCKER_RUN_OPTIONS',
      hidden=True,
      type=arg_parsers.ArgList(),
      help=textwrap.dedent("""
      Custom Docker run options to pass to image during execution.
      For example, '--no-healthcheck, -a stdin'.

      See https://docs.docker.com/engine/reference/commandline/run/#options for
      more details.
      """))

  # Flags for service account
  parser.add_argument(
      '--service-account-key-file',
      metavar='ACCOUNT_KEY_FILE',
      help=textwrap.dedent("""
      The JSON file of a Google Cloud service account private key.
      When specified, the corresponding service account will be used to
      authenticate the local container to access Google Cloud services.
      Note that the key file won't be copied to the container, it will be
      mounted during running time.
      """))

  # User custom flags.
  parser.add_argument(
      'args',
      nargs=argparse.REMAINDER,
      default=[],
      help="""Additional user arguments to be forwarded to your application.""",
      example=('$ {command} --script=my_run.sh --base-image=gcr.io/my/image '
               '-- --my-arg bar --enable_foo'))

View File

@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for local mode."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import io
import os
import subprocess
import sys
from googlecloudsdk.core.util import files
def ExecuteCommand(cmd, input_str=None, file=None):
  """Executes shell commands in subprocess.

  Executes the supplied command with the supplied standard input string,
  streams the output to stdout, and returns the process's return code.

  Args:
    cmd: (List[str]) Strings to send in as the command.
    input_str: (str) if supplied, it will be passed as stdin to the supplied
      command. if None, stdin will get closed immediately.
    file: optional file-like object (stream), the output from the executed
      process's stdout will get sent to this stream. Defaults to sys.stdout.

  Returns:
    return code of the process
  """
  if file is None:
    file = sys.stdout
  with subprocess.Popen(
      cmd,
      stdin=subprocess.PIPE,
      stdout=subprocess.PIPE,
      stderr=subprocess.STDOUT,
      universal_newlines=False,
      bufsize=1) as p:
    if input_str:
      p.stdin.write(input_str.encode('utf-8'))
    p.stdin.close()
    # Wrap the raw byte stream so lines can be forwarded as text; newline=''
    # keeps whatever line endings the child process produced.
    out = io.TextIOWrapper(p.stdout, newline='')
    for line in out:
      file.write(line)
      # Flush per line to force the contents to display as they arrive.
      file.flush()
    # The previous code put this flush in a `for ... else` branch; with no
    # `break` in the loop the else-branch always ran, so flush unconditionally.
    file.flush()
  return p.returncode
def ModuleToPath(module_name):
  """Maps a dotted python module name to its source file path.

  Args:
    module_name: (str) A python module name (separated by dots)

  Returns:
    A string representing a python file path.
  """
  relative_path = module_name.replace('.', os.path.sep)
  return '{}.py'.format(relative_path)
def ClearPyCache(root_dir=None):
  """Removes generic `__pycache__` folder and '*.pyc' '*.pyo' files."""
  root_dir = root_dir or files.GetCWD()
  removed_any = False
  for entry in os.listdir(root_dir):
    full_path = os.path.join(root_dir, entry)
    if os.path.isdir(full_path):
      # Only the generated bytecode cache directory is removed.
      if entry != '__pycache__':
        continue
      files.RmTree(full_path)
      removed_any = True
    elif os.path.splitext(entry)[1] in ('.pyc', '.pyo'):
      os.remove(full_path)
      removed_any = True
  return removed_any

View File

@@ -0,0 +1,376 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Validations of the arguments of custom-jobs command group."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import os
from googlecloudsdk.api_lib.ai import util as api_util
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import validation
from googlecloudsdk.command_lib.ai.custom_jobs import local_util
from googlecloudsdk.command_lib.ai.docker import utils as docker_utils
from googlecloudsdk.core.util import files
def ValidateRegion(region):
  """Checks that the given region is one allowed for custom jobs."""
  supported_regions = constants.SUPPORTED_TRAINING_REGIONS
  validation.ValidateRegion(region, available_regions=supported_regions)
def ValidateCreateArgs(args, job_spec_from_config, version):
  """Validate the argument values specified in `create` command."""
  # TODO(b/186082396): Add more validations for other args.
  specs_from_flags = args.worker_pool_spec
  if not specs_from_flags:
    # No flag-specified worker pools; fall back to the imported config.
    _ValidateWorkerPoolSpecsFromConfig(job_spec_from_config)
  else:
    _ValidateWorkerPoolSpecArgs(specs_from_flags, version)
def _ValidateWorkerPoolSpecArgs(worker_pool_specs, version):
  """Validates the values given via `--worker-pool-spec` flags.

  Args:
    worker_pool_specs: List[dict], a list of worker pool specs specified in
      command line.
    version: str, the API version this command will interact with, either GA or
      BETA.
  """
  first_spec = worker_pool_specs[0]
  if not first_spec:
    raise exceptions.InvalidArgumentException(
        '--worker-pool-spec',
        'Empty value is not allowed for the first `--worker-pool-spec` flag.')
  _ValidateHardwareInWorkerPoolSpecArgs(worker_pool_specs, version)
  _ValidateSoftwareInWorkerPoolSpecArgs(worker_pool_specs)
def _ValidateHardwareInWorkerPoolSpecArgs(worker_pool_specs, api_version):
  """Validates the hardware related fields specified in `--worker-pool-spec` flags.

  Args:
    worker_pool_specs: List[dict], a list of worker pool specs specified in
      command line.
    api_version: str, the API version this command will interact with, either GA
      or BETA.

  Raises:
    exceptions.InvalidArgumentException: if [machine-type] is missing, if
      [accelerator-count] is given without [accelerator-type], or if the
      [accelerator-type] value is not a recognized NVIDIA/TPU accelerator.
  """
  for spec in worker_pool_specs:
    if not spec:
      # Empty specs (other than the first, validated elsewhere) are allowed.
      continue
    if 'machine-type' not in spec:
      raise exceptions.InvalidArgumentException(
          '--worker-pool-spec',
          'Key [machine-type] required in dict arg but not provided.')
    if 'accelerator-count' in spec and 'accelerator-type' not in spec:
      raise exceptions.InvalidArgumentException(
          '--worker-pool-spec',
          'Key [accelerator-type] required as [accelerator-count] is specified.'
      )
    accelerator_type = spec.get('accelerator-type', None)
    if accelerator_type:
      type_enum = api_util.GetMessage(
          'MachineSpec', api_version).AcceleratorTypeValueValuesEnum
      # Renamed the loop variable from `type` to `name`: `type` shadows the
      # Python builtin.
      valid_types = [
          name for name in type_enum.names()
          if name.startswith('NVIDIA') or name.startswith('TPU')
      ]
      if accelerator_type not in valid_types:
        raise exceptions.InvalidArgumentException(
            '--worker-pool-spec',
            ('Found invalid value of [accelerator-type]: {actual}. '
             'Available values are [{expected}].').format(
                 actual=accelerator_type,
                 expected=', '.join(v for v in sorted(valid_types))))
def _ValidateSoftwareInWorkerPoolSpecArgs(worker_pool_specs):
  """Validates the software fields specified in all `--worker-pool-spec` flags."""
  first, rest = worker_pool_specs[0], worker_pool_specs[1:]
  uses_local_package = _ValidateSoftwareInFirstWorkerPoolSpec(first)
  if rest:
    _ValidateSoftwareInRestWorkerPoolSpecs(rest, uses_local_package)
def _ValidateSoftwareInFirstWorkerPoolSpec(spec):
  """Validates the software related fields of the first `--worker-pool-spec` flag.

  Args:
    spec: dict, the specification of the first worker pool.

  Returns:
    A boolean value whether a local package will be used.
  """
  uses_local_package = 'local-package-path' in spec
  if uses_local_package:
    _ValidateWorkerPoolSoftwareWithLocalPackage(spec)
  else:
    _ValidateWorkerPoolSoftwareWithoutLocalPackages(spec)
  return uses_local_package
def _ValidateSoftwareInRestWorkerPoolSpecs(specs,
                                           is_local_package_specified=False):
  """Validates the argument values specified in all but the first `--worker-pool-spec` flags.

  Args:
    specs: List[dict], the list all but the first worker pool specs specified in
      command line.
    is_local_package_specified: bool, whether local package is specified in the
      first worker pool.
  """
  # Invariant over the loop, so built once up front.
  software_fields = {
      'executor-image-uri',
      'container-image-uri',
      'python-module',
      'script',
      'requirements',
      'extra-packages',
      'extra-dirs',
  }
  for spec in specs:
    if not spec:
      continue
    if is_local_package_specified:
      # The first pool's local package applies to every worker; no software
      # keys may appear in the remaining pools.
      _RaiseErrorIfUnexpectedKeys(
          unexpected_keys=software_fields.intersection(spec.keys()),
          reason=('A local package has been specified in the first '
                  '`--worker-pool-spec` flag and to be used for all workers, '
                  'do not specify these keys elsewhere.'))
      continue
    if 'local-package-path' in spec:
      raise exceptions.InvalidArgumentException(
          '--worker-pool-spec',
          ('Key [local-package-path] is only allowed in the first '
           '`--worker-pool-spec` flag.'))
    _ValidateWorkerPoolSoftwareWithoutLocalPackages(spec)
def _ValidateWorkerPoolSoftwareWithLocalPackage(spec):
  """Validate the software in a single `--worker-pool-spec` when `local-package-path` is specified."""
  assert 'local-package-path' in spec
  _RaiseErrorIfNotExists(
      spec['local-package-path'], flag_name='--worker-pool-spec')

  if 'executor-image-uri' not in spec:
    raise exceptions.InvalidArgumentException(
        '--worker-pool-spec',
        'Key [executor-image-uri] is required when `local-package-path` is specified.'
    )

  # Exactly one entry point must be given: error out when both or neither
  # of [python-module, script] is present.
  has_module = 'python-module' in spec
  has_script = 'script' in spec
  if has_module == has_script:
    raise exceptions.InvalidArgumentException(
        '--worker-pool-spec',
        'Exactly one of keys [python-module, script] is required '
        'when `local-package-path` is specified.')

  if 'output-image-uri' in spec:
    output_image = spec['output-image-uri']
    hostname = output_image.split('/')[0]
    container_registries = ['gcr.io', 'eu.gcr.io', 'asia.gcr.io', 'us.gcr.io']
    is_gcr = hostname in container_registries
    is_artifact_registry = hostname.endswith('-docker.pkg.dev')
    if not (is_gcr or is_artifact_registry):
      raise exceptions.InvalidArgumentException(
          '--worker-pool-spec',
          'The value of `output-image-uri` has to be a valid gcr.io or Artifact Registry image'
      )
    try:
      docker_utils.ValidateRepositoryAndTag(output_image)
    except ValueError as e:
      raise exceptions.InvalidArgumentException(
          '--worker-pool-spec',
          r"'{}' is not a valid container image uri: {}".format(
              output_image, e))
def _ValidateWorkerPoolSoftwareWithoutLocalPackages(spec):
  """Validate the software in a single `--worker-pool-spec` when `local-package-path` is not specified."""
  assert 'local-package-path' not in spec
  uses_executor_image = 'executor-image-uri' in spec
  uses_container_image = 'container-image-uri' in spec
  uses_python_module = 'python-module' in spec

  # Both present or both absent is invalid -- exactly one image key is needed.
  if uses_executor_image == uses_container_image:
    raise exceptions.InvalidArgumentException(
        '--worker-pool-spec',
        ('Exactly one of keys [executor-image-uri, container-image-uri] '
         'is required.'))
  if uses_container_image and uses_python_module:
    raise exceptions.InvalidArgumentException(
        '--worker-pool-spec',
        ('Key [python-module] is not allowed together with key '
         '[container-image-uri].'))
  if uses_executor_image and not uses_python_module:
    raise exceptions.InvalidArgumentException(
        '--worker-pool-spec', 'Key [python-module] is required.')

  # These keys only make sense alongside `local-package-path`.
  local_package_only_keys = {
      'script',
      'requirements',
      'extra-packages',
      'extra-dirs',
  }
  _RaiseErrorIfUnexpectedKeys(
      local_package_only_keys.intersection(spec.keys()),
      reason='Only allow to specify together with `local-package-path` in the first `--worker-pool-spec` flag'
  )
def _RaiseErrorIfUnexpectedKeys(unexpected_keys, reason):
  """Raises InvalidArgumentException when any unexpected key is present."""
  if not unexpected_keys:
    return
  key_list = ', '.join(sorted(unexpected_keys))
  raise exceptions.InvalidArgumentException(
      '--worker-pool-spec', 'Keys [{keys}] are not allowed: {reason}.'.format(
          keys=key_list, reason=reason))
def _ValidateWorkerPoolSpecsFromConfig(job_spec):
  """Validate WorkerPoolSpec message instances imported from the config file."""
  # TODO(b/186082396): adds more validations for other fields.
  for spec in job_spec.workerPoolSpecs:
    python_spec = spec.pythonPackageSpec
    use_python_package = bool(
        python_spec and
        (python_spec.executorImageUri or python_spec.pythonModule))
    use_container = bool(spec.containerSpec and spec.containerSpec.imageUri)
    # Both set or both unset is invalid -- exactly one spec kind is required.
    if use_python_package == use_container:
      raise exceptions.InvalidArgumentException(
          '--config',
          ('Exactly one of fields [pythonPackageSpec, containerSpec] '
           'is required for a [workerPoolSpecs] in the YAML config file.'))
def _ImageBuildArgSpecified(args):
  """Returns names of all the flags specified only for image building."""
  # (args attribute, flag name) pairs, in the order flags are reported.
  flag_by_attr = (
      ('script', 'script'),
      ('python_module', 'python-module'),
      ('requirements', 'requirements'),
      ('extra_packages', 'extra-packages'),
      ('extra_dirs', 'extra-dirs'),
      ('output_image_uri', 'output-image-uri'),
  )
  return [flag for attr, flag in flag_by_attr if getattr(args, attr)]
def _ValidBuildArgsOfLocalRun(args):
  """Validates the arguments related to image building and normalizes them.

  Mutates `args` in place: when only `--python-module` is given, `args.script`
  is derived from it; when `--output-image-uri` is absent, a name is
  auto-generated from the script name.

  Args:
    args: The parsed arguments of the `local-run` command; assumed to already
      have `local_package_path` resolved to an existing directory.

  Raises:
    exceptions.MinimumArgumentException: if build flags are given but neither
      `--script` nor `--python-module` is.
    exceptions.InvalidArgumentException: if a referenced script, package file,
      or directory is missing, or `--output-image-uri` is malformed.
  """
  build_args_specified = _ImageBuildArgSpecified(args)
  # Nothing to build: the executor image will be run directly.
  if not build_args_specified:
    return
  if not args.script and not args.python_module:
    raise exceptions.MinimumArgumentException(
        ['--script', '--python-module'],
        'They are required to build a training container image. '
        'Otherwise, please remove flags [{}] to directly run the `executor-image-uri`.'
        .format(', '.join(sorted(build_args_specified))))
  # Validate main script's existence:
  if args.script:
    arg_name = '--script'
  else:
    # Derive the script path from the module name (e.g. "a.b" -> "a/b.py") so
    # the existence check below covers both flags.
    args.script = local_util.ModuleToPath(args.python_module)
    arg_name = '--python-module'
  script_path = os.path.normpath(
      os.path.join(args.local_package_path, args.script))
  if not os.path.exists(script_path) or not os.path.isfile(script_path):
    raise exceptions.InvalidArgumentException(
        arg_name, r"File '{}' is not found under the package: '{}'.".format(
            args.script, args.local_package_path))
  # Validate extra custom packages specified:
  for package in (args.extra_packages or []):
    package_path = os.path.normpath(
        os.path.join(args.local_package_path, package))
    if not os.path.exists(package_path) or not os.path.isfile(package_path):
      raise exceptions.InvalidArgumentException(
          '--extra-packages',
          r"Package file '{}' is not found under the package: '{}'.".format(
              package, args.local_package_path))
  # Validate extra directories specified:
  for directory in (args.extra_dirs or []):
    dir_path = os.path.normpath(
        os.path.join(args.local_package_path, directory))
    if not os.path.exists(dir_path) or not os.path.isdir(dir_path):
      raise exceptions.InvalidArgumentException(
          '--extra-dirs',
          r"Directory '{}' is not found under the package: '{}'.".format(
              directory, args.local_package_path))
  # Validate output image uri is in valid format; otherwise auto-generate one.
  if args.output_image_uri:
    output_image = args.output_image_uri
    try:
      docker_utils.ValidateRepositoryAndTag(output_image)
    except ValueError as e:
      raise exceptions.InvalidArgumentException(
          '--output-image-uri',
          r"'{}' is not a valid container image uri: {}".format(
              output_image, e))
  else:
    args.output_image_uri = docker_utils.GenerateImageName(
        base_name=args.script)
def ValidateLocalRunArgs(args):
  """Validates the arguments specified in `local-run` command and normalizes them.

  Args:
    args: The parsed arguments of the `local-run` command.

  Returns:
    The validated arguments, with `args.local_package_path` normalized to an
    absolute path of an existing directory (defaulting to the current working
    directory when the flag is unset).

  Raises:
    exceptions.InvalidArgumentException: if the specified local package
      directory does not exist.
  """
  # Fixed misspelled local variable (`args_local_package_pach`).
  local_package_path = args.local_package_path
  if local_package_path:
    work_dir = os.path.abspath(files.ExpandHomeDir(local_package_path))
    if not os.path.exists(work_dir) or not os.path.isdir(work_dir):
      raise exceptions.InvalidArgumentException(
          '--local-package-path',
          r"Directory '{}' is not found.".format(work_dir))
  else:
    work_dir = files.GetCWD()
  args.local_package_path = work_dir

  _ValidBuildArgsOfLocalRun(args)
  return args
def _RaiseErrorIfNotExists(local_package_path, flag_name):
  """Checks that the given local package directory exists.

  Args:
    local_package_path: str, path of the local directory to check.
    flag_name: str, indicates in which flag the path is specified.
  """
  work_dir = os.path.abspath(files.ExpandHomeDir(local_package_path))
  is_existing_dir = os.path.exists(work_dir) and os.path.isdir(work_dir)
  if not is_existing_dir:
    raise exceptions.InvalidArgumentException(
        flag_name, r"Directory '{}' is not found.".format(work_dir))

View File

@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform deployment resource pools commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.core import resources
def ParseOperation(operation_name):
  """Parse operation resource to the operation reference object.

  Args:
    operation_name: The operation resource to wait on

  Returns:
    The operation reference object
  """
  drp_collection = (
      'aiplatform.projects.locations.deploymentResourcePools.operations')
  if '/deploymentResourcePools/' in operation_name:
    try:
      return resources.REGISTRY.ParseRelativeName(
          operation_name, collection=drp_collection)
    except resources.WrongResourceCollectionException:
      # Fall through to the generic operations collection below.
      pass
  return resources.REGISTRY.ParseRelativeName(
      operation_name,
      collection='aiplatform.projects.locations.operations')

View File

@@ -0,0 +1,338 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions required to interact with Docker to build images."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import json
import os
import posixpath
import re
import textwrap
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.ai.custom_jobs import local_util
from googlecloudsdk.command_lib.ai.docker import utils
from googlecloudsdk.core import log
from six.moves import shlex_quote
_DEFAULT_HOME = "/home"
_DEFAULT_WORKDIR = "/usr/app"
_DEFAULT_SETUP_PATH = "./setup.py"
_DEFAULT_REQUIREMENTS_PATH = "./requirements.txt"
_AUTONAME_PREFIX = "cloudai-autogenerated"
_AUTOGENERATED_TAG_LENGTH = 16
def _IsVertexTrainingPrebuiltImage(image_name):
  """Checks whether the image is pre-built by Vertex AI training.

  Args:
    image_name: (str) Full name of a Docker image.

  Returns:
    True if the image name matches a Vertex AI training pre-built image.
  """
  # The dots in the registry host are escaped: previously the bare `.` matched
  # any character, so hosts like "us-dockerXpkgXdev" were wrongly accepted.
  prebuilt_image_name_regex = (r"^(us|europe|asia)-docker\.pkg\.dev/"
                               r"vertex-ai/training/"
                               r"(tf|scikit-learn|pytorch|xgboost)-.+$")
  return re.fullmatch(prebuilt_image_name_regex, image_name) is not None
def _SitecustomizeRemovalEntry(is_prebuilt_image):
  """Returns a Dockerfile entry that removes `sitecustomize` if it's Vertex AI Training pre-built container images."""
  if is_prebuilt_image:
    return "RUN rm -rf /var/sitecustomize"
  return ""
def _GenerateCopyCommand(from_path, to_path, comment=None):
  """Returns a Dockerfile entry that copies a file from host to container.

  Args:
    from_path: (str) Path of the source in host.
    to_path: (str) Path to the destination in the container.
    comment: (str) A comment explaining the copy operation.
  """
  # json.dumps yields the double-quoted JSON-array form COPY expects.
  copy_line = "COPY {}\n".format(json.dumps([from_path, to_path]))
  if comment is None:
    return copy_line
  # Prefix every line of a multi-line comment with '# '.
  commented = "\n# ".join(comment.split("\n"))
  return "# {}\n{}".format(commented, copy_line)
def _DependencyEntries(is_prebuilt_image=False,
                       requirements_path=None,
                       setup_path=None,
                       extra_requirements=None,
                       extra_packages=None,
                       extra_dirs=None):
  """Returns the Dockerfile entries required to install dependencies.

  Args:
    is_prebuilt_image: (bool) Whether the base image is pre-built and provided
      by Vertex AI.
    requirements_path: (str) Path that points to a requirements.txt file
    setup_path: (str) Path that points to a setup.py
    extra_requirements: (List[str]) Required dependencies to be installed from
      remote resource archives.
    extra_packages: (List[str]) User custom dependency packages to install.
    extra_dirs: (List[str]) Directories other than the work_dir required.

  Returns:
    A string of Dockerfile COPY/RUN entries installing the dependencies.
  """
  ret = ""
  # Use `pip3` on Vertex AI pre-built images and plain `pip` otherwise.
  pip_version = "pip3" if is_prebuilt_image else "pip"

  if setup_path is not None:
    ret += textwrap.dedent("""
        {}
        RUN {} install --no-cache-dir .
        """.format(
            _GenerateCopyCommand(
                setup_path,
                "./setup.py",
                comment="Found setup.py file, thus copy it to the docker container."
            ), pip_version))

  if requirements_path is not None:
    # Fixed the generated comment ("thus to the docker container" was missing
    # "copy it") so it matches the setup.py wording above.
    ret += textwrap.dedent("""
        {}
        RUN {} install --no-cache-dir -r ./requirements.txt
        """.format(
            _GenerateCopyCommand(
                requirements_path,
                "./requirements.txt",
                comment="Found requirements.txt file, thus copy it to the docker container."
            ), pip_version))

  if extra_packages is not None:
    for extra in extra_packages:
      package_name = os.path.basename(extra)
      ret += textwrap.dedent("""
          {}
          RUN {} install --no-cache-dir {}
          """.format(
              _GenerateCopyCommand(extra, package_name), pip_version,
              shlex_quote(package_name)))

  if extra_requirements is not None:
    for requirement in extra_requirements:
      ret += textwrap.dedent("""
          RUN {} install --no-cache-dir --upgrade {}
          """.format(pip_version, shlex_quote(requirement)))

  if extra_dirs is not None:
    for directory in extra_dirs:
      ret += "\n{}\n".format(_GenerateCopyCommand(directory, directory))

  return ret
def _GenerateEntrypoint(package, is_prebuilt_image=False):
  """Generates dockerfile entry to set the container entrypoint.

  Args:
    package: (Package) Represents the main application copied to the container.
    is_prebuilt_image: (bool) Whether the base image is pre-built and provided
      by Vertex AI.

  Returns:
    A string with Dockerfile directives to set ENTRYPOINT
  """
  # Online python package training installs python3 on every prebuilt image
  # and uses it by default, so stay consistent with that here.
  python_command = "python3" if is_prebuilt_image else "python"
  # json.dumps is used so the arguments print with double quotes, as the
  # Dockerfile exec form requires.
  if package.python_module is None:
    extension = os.path.splitext(package.script)[1]
    launcher = [python_command] if extension == ".py" else ["/bin/bash"]
    exec_str = json.dumps(launcher + [package.script])
  else:
    exec_str = json.dumps([python_command, "-m", package.python_module])
  return "\nENTRYPOINT {}".format(exec_str)
def _PreparePackageEntry(package):
  """Returns the Dockerfile entries required to append at the end before entrypoint.

  Including:
  - copy the parent directory of the main executable into a docker container.
  - inject an entrypoint that executes a script or python module inside that
    directory.

  Args:
    package: (Package) Represents the main application copied to and run in the
      container.
  """
  # Fall back to "." when the script sits directly in the package root.
  source_dir = os.path.dirname(package.script) or "."
  copy_entry = _GenerateCopyCommand(
      source_dir,
      source_dir,
      comment="Copy the source directory into the docker container.")
  return "\n{}\n".format(copy_entry)
def _MakeDockerfile(base_image,
                    main_package,
                    container_workdir,
                    container_home,
                    requirements_path=None,
                    setup_path=None,
                    extra_requirements=None,
                    extra_packages=None,
                    extra_dirs=None):
  """Generates a Dockerfile for building an image.

  It builds on a specified base image to create a container that:
  - installs any dependency specified in a requirements.txt or a setup.py file,
    and any specified dependency packages existing locally or found from PyPI
  - copies all source needed by the main module, and potentially injects an
    entrypoint that, on run, will run that main module

  Args:
    base_image: (str) ID or name of the base image to initialize the build
      stage.
    main_package: (Package) Represents the main application to execute.
    container_workdir: (str) Working directory in the container.
    container_home: (str) $HOME directory in the container.
    requirements_path: (str) Path of a requirements.txt file.
    setup_path: (str) Path of a setup.py file
    extra_requirements: (List[str]) Required dependencies to install from PyPI.
    extra_packages: (List[str]) User custom dependency packages to install.
    extra_dirs: (List[str]) Directories other than the work_dir required to be
      in the container.

  Returns:
    A string that represents the content of a Dockerfile.
  """
  # Pre-built Vertex AI training bases get special handling (pip3,
  # sitecustomize removal) in the helpers below.
  is_training_prebuilt_image_base = _IsVertexTrainingPrebuiltImage(base_image)
  dockerfile = textwrap.dedent("""
      FROM {base_image}

      # The directory is created by root. This sets permissions so that any user can
      # access the folder.
      RUN mkdir -m 777 -p {workdir} {container_home}
      WORKDIR {workdir}
      ENV HOME={container_home}

      # Keeps Python from generating .pyc files in the container
      ENV PYTHONDONTWRITEBYTECODE=1
      """.format(
          base_image=base_image,
          workdir=shlex_quote(container_workdir),
          container_home=shlex_quote(container_home)))
  dockerfile += _SitecustomizeRemovalEntry(is_training_prebuilt_image_base)
  # Dependencies are installed before the sources are copied (presumably to
  # benefit Docker layer caching on source-only changes -- TODO confirm).
  dockerfile += _DependencyEntries(
      is_training_prebuilt_image_base,
      requirements_path=requirements_path,
      setup_path=setup_path,
      extra_requirements=extra_requirements,
      extra_packages=extra_packages,
      extra_dirs=extra_dirs)
  dockerfile += _PreparePackageEntry(main_package)
  dockerfile += _GenerateEntrypoint(main_package,
                                    is_training_prebuilt_image_base)
  return dockerfile
def BuildImage(base_image,
               host_workdir,
               main_script,
               output_image_name,
               python_module=None,
               requirements=None,
               extra_packages=None,
               container_workdir=None,
               container_home=None,
               no_cache=True,
               **kwargs):
  """Builds a Docker image.

  Generates a Dockerfile and passes it to `docker build` via stdin.
  All output from the `docker build` process prints to stdout.

  Args:
    base_image: (str) ID or name of the base image to initialize the build
      stage.
    host_workdir: (str) A path indicating where all the required sources
      locates.
    main_script: (str) A string that identifies the executable script under the
      working directory.
    output_image_name: (str) Name of the built image.
    python_module: (str) Represents the executable main_script in form of a
      python module, if applicable.
    requirements: (List[str]) Required dependencies to install from PyPI.
    extra_packages: (List[str]) User custom dependency packages to install.
    container_workdir: (str) Working directory in the container.
    container_home: (str) the $HOME directory in the container.
    no_cache: (bool) Do not use cache when building the image.
    **kwargs: Other arguments to pass to underlying method that generates the
      Dockerfile.

  Returns:
    A Image class that contains info of the built image.

  Raises:
    DockerError: An error occurred when executing `docker build`
  """
  tag_options = ["-t", output_image_name]
  cache_args = ["--no-cache"] if no_cache else []
  # "-f-" tells `docker build` to read the Dockerfile from stdin.
  command = ["docker", "build"
            ] + cache_args + tag_options + ["--rm", "-f-", host_workdir]
  # Pick up setup.py / requirements.txt at the root of the host work
  # directory, when present, so their dependencies get installed.
  has_setup_py = os.path.isfile(os.path.join(host_workdir, _DEFAULT_SETUP_PATH))
  setup_path = _DEFAULT_SETUP_PATH if has_setup_py else None
  has_requirements_txt = os.path.isfile(
      os.path.join(host_workdir, _DEFAULT_REQUIREMENTS_PATH))
  requirements_path = _DEFAULT_REQUIREMENTS_PATH if has_requirements_txt else None
  home_dir = container_home or _DEFAULT_HOME
  work_dir = container_workdir or _DEFAULT_WORKDIR
  # The package will be used in Docker, thus norm it to POSIX path format.
  main_package = utils.Package(
      script=main_script.replace(os.sep, posixpath.sep),
      package_path=host_workdir.replace(os.sep, posixpath.sep),
      python_module=python_module)
  dockerfile = _MakeDockerfile(
      base_image,
      main_package=main_package,
      container_home=home_dir,
      container_workdir=work_dir,
      requirements_path=requirements_path,
      setup_path=setup_path,
      extra_requirements=requirements,
      extra_packages=extra_packages,
      **kwargs)
  joined_command = " ".join(command)
  log.info("Running command: {}".format(joined_command))
  # The generated Dockerfile content is fed to the subprocess's stdin.
  return_code = local_util.ExecuteCommand(command, input_str=dockerfile)
  if return_code == 0:
    return utils.Image(output_image_name, home_dir, work_dir)
  else:
    error_msg = textwrap.dedent("""
        Docker failed with error code {code}.
        Command: {cmd}
        """.format(code=return_code, cmd=joined_command))
    raise errors.DockerError(error_msg, command, return_code)

View File

@@ -0,0 +1,87 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions required to interact with Docker to run a container."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.command_lib.ai.docker import utils
from googlecloudsdk.core import config
_DEFAULT_CONTAINER_CRED_KEY_PATH = "/tmp/keys/cred_key.json"
def _DockerRunOptions(enable_gpu=False,
                      service_account_key=None,
                      cred_mount_path=_DEFAULT_CONTAINER_CRED_KEY_PATH,
                      extra_run_opts=None):
  """Builds the list of 'docker run' options.

  Args:
    enable_gpu: (bool) using GPU or not.
    service_account_key: (str) path of the service account key to use in host.
    cred_mount_path: (str) path in the container to mount the credential key.
    extra_run_opts: (List[str]) other custom docker run options.
  """
  if extra_run_opts is None:
    extra_run_opts = []

  runtime = ["--runtime", "nvidia"] if enable_gpu else []

  if service_account_key:
    host_key_path = service_account_key
  else:
    # No explicit key: fall back to Application Default Credentials (ADC).
    host_key_path = config.ADCEnvVariable() or config.ADCFilePath()
  mount = ["-v", "{}:{}".format(host_key_path, cred_mount_path)]

  env_var = ["-e", "GOOGLE_APPLICATION_CREDENTIALS={}".format(cred_mount_path)]

  options = ["--rm"]
  options += runtime
  options += mount
  options += env_var
  options += ["--ipc", "host"]
  return options + extra_run_opts
def RunContainer(image_name,
                 enable_gpu=False,
                 service_account_key=None,
                 run_args=None,
                 user_args=None):
  """Calls `docker run` on a given image with specified arguments.

  Args:
    image_name: (str) Name or ID of Docker image to run.
    enable_gpu: (bool) Whether to use GPU
    service_account_key: (str) Json file of a service account key auth.
    run_args: (List[str]) Extra custom options to apply to `docker run` after
      our defaults.
    user_args: (List[str]) Extra user defined arguments to supply to the
      entrypoint.
  """
  # TODO(b/177787660): add interactive mode option
  run_args = [] if run_args is None else run_args
  user_args = [] if user_args is None else user_args

  run_opts = _DockerRunOptions(
      enable_gpu=enable_gpu,
      service_account_key=service_account_key,
      extra_run_opts=run_args)
  command = ["docker", "run"] + run_opts + [image_name] + user_args
  utils.ExecuteDockerCommand(command)

View File

@@ -0,0 +1,244 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities to operate with Docker."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import collections
import datetime
import re
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.ai.custom_jobs import local_util
from googlecloudsdk.core import log
_MAX_REPOSITORY_LENGTH = 255
_MAX_TAG_LENGTH = 128
_AUTONAME_PREFIX = "cloudai-autogenerated"
_DEFAULT_IMAGE_NAME = "unnamed"
_DEFAULT_REPO_REGION = "us"
# A user program to run in a container: `script` is the executable file
# relative to `package_path`; `python_module` is its dotted module form when
# applicable, otherwise None.
Package = collections.namedtuple("Package",
                                 ["script", "package_path", "python_module"])
# A built Docker image: its `name` plus the default $HOME and working
# directory configured inside the container.
Image = collections.namedtuple("Image",
                               ["name", "default_home", "default_workdir"])
def _ParseRepositoryTag(image_name):
  """Parses out the repository and tag from a Docker image name.

  Args:
    image_name: (str) The full name of an image, expected to be in a format of
      "repository[:tag]"

  Returns:
    A (repository, tag) tuple representing the parsed result.
    None repository means the image name is invalid; tag may be None if it
    isn't present in the given image name.
  """
  # At most two colons can legally appear: one in "host:port", one before tag.
  if image_name.count(":") > 2:
    return None, None
  repository, sep, candidate_tag = image_name.rpartition(":")
  # A slash after the last colon means it belongs to "host:port", not a tag.
  if sep and "/" not in candidate_tag:
    return repository, candidate_tag
  return image_name, None
def _ParseRepositoryHost(repository_name):
  """Parses a repository into an optional hostname and a list of path components.

  Args:
    repository_name: (str) A name made up of slash-separated path name
      components, optionally prefixed by a registry hostname.

  Returns:
    A (hostname, components) tuple representing the parsed result.
    The hostname will be None if it isn't present; the components is a list of
    each slash-separated part in the given repository name.
  """
  components = repository_name.split("/")
  # The leading component counts as a registry hostname only when it contains
  # a dot (domain) or a colon (port).
  if len(components) > 1 and ("." in components[0] or ":" in components[0]):
    return components[0], components[1:]
  return None, components
def _ParseHostPort(host):
"""Parses a registry hostname to a list of components and an optional port.
Args:
host: (str) The registry hostname supposed to comply with standard DNS
rules, optionally be followed by a port number in the format like ":8080".
Returns:
A (hostcomponents, port) tuple representing the parsed result.
The hostcomponents contains each dot-seperated component in the given
hostname; port may be None if it isn't present.
"""
parts = host.rsplit(":", 1)
hostcomponents = parts[0].split(".")
port = parts[1] if len(parts) == 2 else None
return hostcomponents, port
def ValidateRepositoryAndTag(image_name):
  r"""Validate the given image name is a valid repository/tag reference.

  As explained in
  https://docs.docker.com/engine/reference/commandline/tag/#extended-description,
  a valid repository/tag reference should following the below pattern:

  reference := name [ ":" tag ]
  name := [hostname '/'] component ['/' component]*
  hostname := hostcomponent ['.' hostcomponent]* [':' port-number]
  hostcomponent := /([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9-]*[a-zA-Z0-9])/
  port-number := /[0-9]+/
  component := alpha-numeric [separator alpha-numeric]*
  alpha-numeric := /[a-z0-9]+/
  separator := /[_.]|__|[-]*/
  tag := /[\w][\w.-]{0,127}/

  Args:
    image_name: (str) Full name of a Docker image.

  Raises:
    ValueError if the image name is not valid.
  """
  repository, tag = _ParseRepositoryTag(image_name)
  if repository is None:
    raise ValueError("Unable to parse repository and tag.")
  if len(repository) > _MAX_REPOSITORY_LENGTH:
    raise ValueError(
        "Repository name must not be more than {} characters.".format(
            _MAX_REPOSITORY_LENGTH))

  hostname, path_components = _ParseRepositoryHost(repository)

  if hostname:
    # Check every dot-separated host component, then the optional port.
    hostcomponents, port = _ParseHostPort(hostname)
    hostcomponent_regex = r"^(?:[a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9-]*[a-zA-Z0-9])$"
    host_invalid = any(
        re.match(hostcomponent_regex, part) is None
        for part in hostcomponents)
    if not host_invalid and port:
      host_invalid = re.match(r"^[0-9]+$", port) is None
    if host_invalid:
      raise ValueError(
          "Invalid hostname/port \"{}\" in repository name.".format(hostname))

  for component in path_components:
    if not component:
      raise ValueError("Empty path component in repository name.")
    component_regex = r"^[a-z0-9]+(?:(?:[._]|__|[-]*)[a-z0-9]+)*$"
    if re.match(component_regex, component) is None:
      raise ValueError(
          "Invalid path component \"{}\" in repository name.".format(component))

  if tag:
    if len(tag) > _MAX_TAG_LENGTH:
      raise ValueError("Tag name must not be more than {} characters.".format(
          _MAX_TAG_LENGTH))
    if re.match(r"^[\w][\w.-]{0,127}$", tag) is None:
      raise ValueError("Invalid tag.")
def GenerateImageName(base_name=None, project=None, region=None, is_gcr=False):
  """Generate a name for the Docker image built by AI platform gcloud."""
  # A microsecond timestamp tag keeps successive builds distinguishable.
  tag = datetime.datetime.now().strftime("%Y%m%d.%H.%M.%S.%f")
  image_name = "{}/{}:{}".format(
      _AUTONAME_PREFIX,
      _SanitizeRepositoryName(base_name or _DEFAULT_IMAGE_NAME),
      tag)
  if not project:
    return image_name
  if is_gcr:
    repository = "gcr.io"
  else:
    repository = "{}-docker.pkg.dev".format(region or _DEFAULT_REPO_REGION)
  # Domain-scoped projects ("domain:project") map to nested repo paths.
  return "{}/{}/{}".format(repository, project.replace(":", "/"), image_name)
def _SanitizeRepositoryName(name):
"""Sanitizes the given name to make it valid as an image repository.
As explained in
https://docs.docker.com/engine/reference/commandline/tag/#extended-description,
Valid name may contain only lowercase letters, digits and separators.
A separator is defined as a period, one or two underscores, or one or more
dashes. A name component may not start or end with a separator.
This method will replace the illegal characters in the given name and strip
starting and ending separator characters.
Args:
name: str, the name to sanitize.
Returns:
A sanitized name.
"""
return re.sub("[._][._]+|[^a-z0-9._-]+", ".", name.lower()).strip("._-")
def ExecuteDockerCommand(command):
  """Executes Docker CLI commands in subprocess.

  Just calls local_util.ExecuteCommand(cmd,...) and raises error for non-zero
  exit code.

  Args:
    command: (List[str]) Strings to send in as the command.

  Raises:
    ValueError: The input command is not a docker command.
    DockerError: An error occurred when executing the given docker command.
  """
  command_str = " ".join(command)
  if not command_str.startswith("docker"):
    # Bug fix: report the actual rejected command; the original formatted the
    # literal string "docker" into the message instead of the command itself.
    raise ValueError("`{}` is not a Docker command".format(command_str))
  log.info("Running command: {}".format(command_str))
  return_code = local_util.ExecuteCommand(command)
  if return_code != 0:
    error_msg = """
        Docker failed with error code {code}.
        Command: {cmd}
        """.format(code=return_code, cmd=command_str)
    raise errors.DockerError(error_msg, command, return_code)

View File

@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for operating on endpoints for different regions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import contextlib
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from six.moves.urllib import parse
def DeriveAiplatformRegionalEndpoint(endpoint, region, is_prediction=False):
  """Adds region as a prefix of the base url."""
  parsed = parse.urlparse(endpoint)
  netloc = parsed.netloc
  # Only rewrite hosts that belong to the aiplatform service.
  if netloc.startswith('aiplatform'):
    prefix = '{}-prediction'.format(region) if is_prediction else region
    netloc = '{}-{}'.format(prefix, netloc)
  return parse.urlunparse(parsed._replace(netloc=netloc))
@contextlib.contextmanager
def AiplatformEndpointOverrides(version, region, is_prediction=False):
  """Context manager to override the AI Platform endpoints for a while.

  Raises an error if
  region is not set.

  Args:
    version: str, implies the version that the endpoint will use.
    region: str, region of the AI Platform stack.
    is_prediction: bool, it's for prediction endpoint or not.

  Yields:
    None
  """
  # Resolve the regionalized endpoint once so the override and the printed
  # message agree.
  used_endpoint = GetEffectiveEndpoint(version=version, region=region,
                                       is_prediction=is_prediction)
  log.status.Print('Using endpoint [{}]'.format(used_endpoint))
  # NOTE(review): the property override is set but never restored after the
  # context exits (no try/finally reset), so despite the "for a while" in the
  # summary it persists for the process — confirm whether callers rely on
  # that.
  properties.VALUES.api_endpoint_overrides.aiplatform.Set(used_endpoint)
  yield
def GetEffectiveEndpoint(version, region, is_prediction=False):
  """Returns regional AI Platform endpoint, or raise an error if the region not set."""
  base_endpoint = apis.GetEffectiveApiEndpoint(
      constants.AI_PLATFORM_API_NAME,
      constants.AI_PLATFORM_API_VERSION[version])
  # Prefix the base endpoint with the region (and prediction marker).
  return DeriveAiplatformRegionalEndpoint(
      base_endpoint, region, is_prediction=is_prediction)

View File

@@ -0,0 +1,140 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform endpoints commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import io
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml
from googlecloudsdk.core.console import console_io
def ParseOperation(operation_name):
  """Parse operation resource to the operation reference object.

  Args:
    operation_name: The operation resource to wait on

  Returns:
    The operation reference object
  """
  # Endpoint-scoped operations live in their own collection; try it first and
  # fall back to the generic location-scoped collection.
  candidates = []
  if '/endpoints/' in operation_name:
    candidates.append('aiplatform.projects.locations.endpoints.operations')
  candidates.append('aiplatform.projects.locations.operations')
  for position, collection in enumerate(candidates):
    try:
      return resources.REGISTRY.ParseRelativeName(
          operation_name, collection=collection)
    except resources.WrongResourceCollectionException:
      # Only swallow the error while another collection remains to try.
      if position == len(candidates) - 1:
        raise
def _LoadYaml(file_path, sdk_method):
  """Loads a YAML file."""
  raw_bytes = console_io.ReadFromFileOrStdin(file_path, binary=True)
  with io.BytesIO(raw_bytes) as stream_obj:
    try:
      return yaml.load(stream_obj)
    except ValueError:
      # Surface a user-facing error pointing at the relevant command's help.
      raise errors.InvalidInstancesFileError(
          f'Input is not in JSON format. See `gcloud ai endpoints {sdk_method}'
          ' --help` for details.'
      )
def ReadInstancesFromArgs(json_request):
  """Reads the instances from the given file path ('-' for stdin).

  Args:
    json_request: str or None, a path to a file ('-' for stdin) containing the
      JSON body of a prediction request.

  Returns:
    The parsed request dict, whose 'instances' value is a list.

  Raises:
    InvalidInstancesFileError: If the input file is invalid (invalid format or
      contains too many/zero instances), or an improper combination of input
      files was given.
  """
  request = _LoadYaml(json_request, sdk_method='predict')
  if not isinstance(request, dict):
    raise errors.InvalidInstancesFileError(
        'Input instances are not in JSON format. '
        'See `gcloud ai endpoints predict --help` for details.'
    )
  missing = object()  # Sentinel: distinguishes an absent key from a None value.
  instances = request.get('instances', missing)
  if instances is missing:
    raise errors.InvalidInstancesFileError(
        'Invalid JSON request: missing "instances" attribute'
    )
  if not isinstance(instances, list):
    raise errors.InvalidInstancesFileError(
        'Invalid JSON request: "instances" must be a list'
    )
  return request
def ReadInputsFromArgs(json_request):
  """Validates and reads json request for Direct Prediction."""
  parsed = _LoadYaml(json_request, sdk_method='direct-predict')
  # The direct-predict API requires a top-level "inputs" key.
  if 'inputs' in parsed:
    return parsed
  raise errors.InvalidInstancesFileError('Input json must contain "inputs"')
def ReadInputFromArgs(json_request):
  """Validates and reads json request for Direct Raw Prediction."""
  parsed = _LoadYaml(json_request, sdk_method='direct-raw-predict')
  if 'input' not in parsed:
    raise errors.InvalidInstancesFileError('Input json must contain "input"')
  # Accept both snake_case and camelCase spellings of the method name.
  has_method = 'method_name' in parsed or 'methodName' in parsed
  if not has_method:
    raise errors.InvalidInstancesFileError(
        'Input json must contain "method_name" or "methodName"'
    )
  return parsed
def GetDefaultFormat(predictions, key_name='predictions'):
  """Get default output format for prediction results.

  Args:
    predictions: the value of the predictions field from the API response.
    key_name: str, the key under which the predictions appear in the resource
      being formatted.

  Returns:
    A gcloud resource-projection format string, or None to use the default
    formatter.
  """
  if not isinstance(predictions, list):
    # This usually indicates some kind of error case, so surface the full API
    # response
    return 'json'
  elif not predictions:
    # Nothing to render; let the caller fall back to default formatting.
    return None
  # predictions is guaranteed by API contract to be a list of similarly shaped
  # objects, but we don't know ahead of time what those objects look like.
  elif isinstance(predictions[0], dict):
    # Dict-shaped predictions: render one column per key of the first entry.
    keys = ', '.join(sorted(predictions[0].keys()))
    return """
        table(
            {}:format="table(
                {}
            )"
        )""".format(key_name, keys)
  else:
    # Scalar predictions: one value per row, no heading.
    return 'table[no-heading]({})'.format(key_name)

View File

@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition for errors in AI Platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.core import exceptions
class ArgumentError(exceptions.Error):
  """An error raised for invalid arguments in AI Platform commands."""
  pass
class InvalidInstancesFileError(exceptions.Error):
  """Indicates that the input file was invalid in some way."""
class NoFieldsSpecifiedError(exceptions.Error):
  """Error indicating that no updates were requested in a Patch operation."""
class DockerError(exceptions.Error):
  """Exception that passes info on a failed Docker command."""

  def __init__(self, message, cmd, exit_code):
    # Record everything a caller needs to diagnose the failed command.
    self.message = message
    self.cmd = cmd
    self.exit_code = exit_code
    super(DockerError, self).__init__(message)

View File

@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flag definitions specifically for gcloud ai hp-tuning-jobs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
from googlecloudsdk.calliope.concepts import concepts
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import flags
from googlecloudsdk.command_lib.ai import region_util
from googlecloudsdk.command_lib.ai.hp_tuning_jobs import hp_tuning_jobs_util
from googlecloudsdk.command_lib.util.apis import arg_utils
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.command_lib.util.concepts import concept_parsers
_HPTUNING_JOB_DISPLAY_NAME = base.Argument(
'--display-name',
required=True,
help=('Display name of the hyperparameter tuning job to create.'))
# The parameter max-trial-count and parallel-trial-count can be set through
# command line or config.yaml file. Setting the values to be None to indicate
# the value is not set through command line by the customers. If both command
# line and config.yaml file don't set the field, we set it to be 1.
_HPTUNING_MAX_TRIAL_COUNT = base.Argument(
'--max-trial-count',
type=int,
default=None,
help='Desired total number of trials. The default value is 1.')
_HPTUNING_PARALLEL_TRIAL_COUNT = base.Argument(
'--parallel-trial-count',
type=int,
default=None,
help='Desired number of Trials to run in parallel. The default value is 1.')
_HPTUNING_JOB_CONFIG = base.Argument(
'--config',
required=True,
help="""
Path to the job configuration file. This file should be a YAML document containing a HyperparameterTuningSpec.
If an option is specified both in the configuration file **and** via command line arguments, the command line arguments
override the configuration file.
Example(YAML):
displayName: TestHpTuningJob
maxTrialCount: 1
parallelTrialCount: 1
studySpec:
metrics:
- metricId: x
goal: MINIMIZE
parameters:
- parameterId: z
integerValueSpec:
minValue: 1
maxValue: 100
algorithm: RANDOM_SEARCH
trialJobSpec:
workerPoolSpecs:
- machineSpec:
machineType: n1-standard-4
replicaCount: 1
containerSpec:
imageUri: gcr.io/ucaip-test/ucaip-training-test
""")
def AddCreateHpTuningJobFlags(parser, algorithm_enum):
  """Adds arguments for creating hp tuning job.

  Args:
    parser: the argparse parser for the command.
    algorithm_enum: the API message enum type used to build the `--algorithm`
      choice flag.
  """
  _HPTUNING_JOB_DISPLAY_NAME.AddToParser(parser)
  _HPTUNING_JOB_CONFIG.AddToParser(parser)
  _HPTUNING_MAX_TRIAL_COUNT.AddToParser(parser)
  _HPTUNING_PARALLEL_TRIAL_COUNT.AddToParser(parser)
  labels_util.AddCreateLabelsFlags(parser)
  # Region is a resource arg so the user can be prompted for it interactively.
  flags.AddRegionResourceArg(
      parser,
      'to create a hyperparameter tuning job',
      prompt_func=region_util.GetPromptForRegionFunc(
          constants.SUPPORTED_TRAINING_REGIONS))
  flags.TRAINING_SERVICE_ACCOUNT.AddToParser(parser)
  flags.NETWORK.AddToParser(parser)
  flags.ENABLE_WEB_ACCESS.AddToParser(parser)
  flags.ENABLE_DASHBOARD_ACCESS.AddToParser(parser)
  flags.AddKmsKeyResourceArg(parser, 'hyperparameter tuning job')
  # Map the API's algorithm enum onto a --algorithm choice flag.
  arg_utils.ChoiceEnumMapper(
      '--algorithm',
      algorithm_enum,
      help_str='Search algorithm specified for the given study. '
  ).choice_arg.AddToParser(parser)
def AddHptuningJobResourceArg(parser,
                              verb,
                              regions=constants.SUPPORTED_TRAINING_REGIONS):
  """Adds a resource argument for a Vertex AI hyperparameter tuning job.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
    regions: list[str], the list of supported regions.
  """
  region_config = flags.RegionAttributeConfig(
      prompt_func=region_util.GetPromptForRegionFunc(regions))
  spec = concepts.ResourceSpec(
      resource_collection=hp_tuning_jobs_util.HPTUNING_JOB_COLLECTION,
      resource_name='hyperparameter tuning job',
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      locationsId=region_config,
      disable_auto_completers=False)
  resource_parser = concept_parsers.ConceptParser.ForResource(
      'hptuning_job',
      spec,
      'The hyperparameter tuning job {}.'.format(verb),
      required=True)
  resource_parser.AddToParser(parser)

View File

@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Vertex AI hyperparameter tuning jobs commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
from googlecloudsdk.core import resources
# Resource collection for Vertex AI hyperparameter tuning jobs.
HPTUNING_JOB_COLLECTION = 'aiplatform.projects.locations.hyperparameterTuningJobs'
def ParseJobName(name):
  """Parses the id from a full hyperparameter tuning job name."""
  job_ref = resources.REGISTRY.Parse(name, collection=HPTUNING_JOB_COLLECTION)
  return job_ref.Name()
def OutputCommandVersion(release_track):
  """Returns the gcloud release-track infix to splice into command output.

  Args:
    release_track: base.ReleaseTrack, the track the command runs under.

  Returns:
    '' for GA, ' beta' for BETA, and ' alpha' for anything else. Note the
    leading space, so the value can be inserted directly after 'gcloud'.
  """
  if release_track == base.ReleaseTrack.GA:
    return ''
  elif release_track == base.ReleaseTrack.BETA:
    return ' beta'
  else:
    return ' alpha'

View File

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform index endpoints commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.core import resources
def BuildParentOperation(project_id, location_id, index_endpoint_id,
                         operation_id):
  """Build multi-parent operation."""
  relative_name = (
      'projects/{}/locations/{}/indexEndpoints/{}/operations/{}'.format(
          project_id, location_id, index_endpoint_id, operation_id))
  return ParseIndexEndpointOperation(relative_name)
def ParseIndexEndpointOperation(operation_name):
  """Parse operation relative resource name to the operation reference object.

  Args:
    operation_name: The operation resource name

  Returns:
    The operation reference object
  """
  # Prefer the index-endpoint-scoped collection when the name looks like one;
  # otherwise (or on a collection mismatch) use the generic collection.
  candidates = []
  if '/indexEndpoints/' in operation_name:
    candidates.append(
        'aiplatform.projects.locations.indexEndpoints.operations')
  candidates.append('aiplatform.projects.locations.operations')
  for position, collection in enumerate(candidates):
    try:
      return resources.REGISTRY.ParseRelativeName(
          operation_name, collection=collection)
    except resources.WrongResourceCollectionException:
      # Only swallow the error while another collection remains to try.
      if position == len(candidates) - 1:
        raise

View File

@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform indexes commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.core import resources
def BuildIndexParentOperation(project_id, location_id, index_id, operation_id):
  """Build multi-parent operation."""
  relative_name = 'projects/{}/locations/{}/indexes/{}/operations/{}'.format(
      project_id, location_id, index_id, operation_id)
  return ParseIndexOperation(relative_name)
def ParseIndexOperation(operation_name):
  """Parse operation relative resource name to the operation reference object.

  Args:
    operation_name: The operation resource name

  Returns:
    The operation reference object
  """
  # Prefer the index-scoped collection when the name looks like one;
  # otherwise (or on a collection mismatch) use the generic collection.
  candidates = []
  if '/indexes/' in operation_name:
    candidates.append('aiplatform.projects.locations.indexes.operations')
  candidates.append('aiplatform.projects.locations.operations')
  for position, collection in enumerate(candidates):
    try:
      return resources.REGISTRY.ParseRelativeName(
          operation_name, collection=collection)
    except resources.WrongResourceCollectionException:
      # Only swallow the error while another collection remains to try.
      if position == len(candidates) - 1:
        raise

View File

@@ -0,0 +1,134 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for interacting with streaming logs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import copy
from apitools.base.py import encoding
from googlecloudsdk.command_lib.logs import stream
import six
# gcloud resource-projection format used to render streamed log entries.
LOG_FORMAT = ('value('
              'severity,'
              'timestamp.date("%Y-%m-%d %H:%M:%S %z",tz="LOCAL"), '
              'task_name,'
              'message'
              ')')

# Passed to stream.LogFetcher as continue_interval — presumably the number of
# empty polls between continue_func checks; confirm against LogFetcher docs.
_CONTINUE_INTERVAL = 10
def StreamLogs(name, continue_function, polling_interval, task_name,
               allow_multiline):
  """Returns the streaming log of the job by id.

  Args:
    name: string id of the entity.
    continue_function: One-arg function that takes in the number of empty polls
      and outputs a boolean to decide if we should keep polling or not. If not
      given, keep polling indefinitely.
    polling_interval: amount of time to sleep between each poll.
    task_name: String name of task.
    allow_multiline: Tells us if logs with multiline messages are okay or not.

  Returns:
    A generator of log dictionaries (one per line unless multiline is
    allowed).
  """
  fetcher = stream.LogFetcher(
      filters=_LogFilters(name, task_name=task_name),
      polling_interval=polling_interval,
      continue_interval=_CONTINUE_INTERVAL,
      continue_func=continue_function)
  raw_logs = fetcher.YieldLogs()
  return _SplitMultiline(raw_logs, allow_multiline)
def _LogFilters(name, task_name):
"""Returns filters for log fetcher to use.
Args:
name: string id of the entity.
task_name: String name of task.
Returns:
A list of filters to be passed to the logging API.
"""
filters = [
'resource.type="ml_job"', 'resource.labels.job_id="{0}"'.format(name)
]
if task_name:
filters.append('resource.labels.task_name="{0}"'.format(task_name))
return filters
def _SplitMultiline(log_generator, allow_multiline=False):
  """Splits the dict output of logs into multiple lines.

  Args:
    log_generator: iterator that returns a an ml log in dict format.
    allow_multiline: Tells us if logs with multiline messages are okay or not.

  Yields:
    Single-line ml log dictionaries.
  """
  for entry in log_generator:
    entry_dict = _EntryToDict(entry)
    if allow_multiline:
      yield entry_dict
      continue
    # Emit one copy of the entry per message line; an empty message still
    # yields one (empty) line.
    lines = entry_dict['message'].splitlines() or ['']
    for line in lines:
      per_line_entry = copy.deepcopy(entry_dict)
      per_line_entry['message'] = line
      yield per_line_entry
def _EntryToDict(log_entry):
  """Converts a log entry to a dictionary."""
  severity = log_entry.severity.name if log_entry.severity else 'DEFAULT'
  message_parts = []
  if log_entry.jsonPayload is not None:
    payload = _ToDict(log_entry.jsonPayload)
    # 'message' contains a free-text message that we want to pull out of the
    # JSON.
    if payload.get('message'):
      message_parts.append(payload['message'])
  elif log_entry.textPayload is not None:
    message_parts.append(six.text_type(log_entry.textPayload))
  return {
      'severity': severity,
      'timestamp': log_entry.timestamp,
      'task_name': _GetTaskName(log_entry),
      'message': ''.join(message_parts),
  }
def _GetTaskName(log_entry):
  """Reads the label attributes of the given log entry."""
  labels = _ToDict(log_entry.resource.labels) if log_entry.resource else {}
  # Fall back when the label is absent or empty.
  return labels.get('task_name') or 'unknown_task'
def _ToDict(message):
if not message:
return {}
if isinstance(message, dict):
return message
else:
return encoding.MessageToDict(message)

View File

@@ -0,0 +1,613 @@
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for the model garden command group."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import datetime
from apitools.base.py import encoding
from apitools.base.py import exceptions as apitools_exceptions
from googlecloudsdk.api_lib.ai import operations
from googlecloudsdk.api_lib.ai.models import client as client_models
from googlecloudsdk.api_lib.monitoring import metric
from googlecloudsdk.api_lib.quotas import quota_info
from googlecloudsdk.command_lib.ai import endpoints_util
from googlecloudsdk.command_lib.ai import models_util
from googlecloudsdk.command_lib.ai import operations_util
from googlecloudsdk.core import exceptions as core_exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import requests
from googlecloudsdk.core import resources
from googlecloudsdk.core.console import console_io
# Maximum length enforced for endpoint label values built in this module.
_MAX_LABEL_VALUE_LENGTH = 63

# Maps accelerator types to the quota IDs used by the Cloud Quotas API.
_ACCELERATOR_TYPE_TO_QUOTA_ID_MAP = {
    'NVIDIA_TESLA_P4': 'CustomModelServingP4GPUsPerProjectPerRegion',
    'NVIDIA_TESLA_T4': 'CustomModelServingT4GPUsPerProjectPerRegion',
    'NVIDIA_L4': 'CustomModelServingL4GPUsPerProjectPerRegion',
    'NVIDIA_TESLA_K80': 'CustomModelServingK80GPUsPerProjectPerRegion',
    'NVIDIA_TESLA_V100': 'CustomModelServingV100GPUsPerProjectPerRegion',
    'NVIDIA_TESLA_P100': 'CustomModelServingP100GPUsPerProjectPerRegion',
    'NVIDIA_TESLA_A100': 'CustomModelServingA100GPUsPerProjectPerRegion',
    'NVIDIA_A100_80GB': 'CustomModelServingA10080GBGPUsPerProjectPerRegion',
    'NVIDIA_H100_80GB': 'CustomModelServingH100GPUsPerProjectPerRegion',
    'TPU_V5_LITEPOD': 'CustomModelServingV5ETPUPerProjectPerRegion',
}

# Maps accelerator types to the quota metric names used by Cloud Monitoring.
# NOTE(review): "_TP_" in the name looks like a typo for "_TO_"; it cannot be
# renamed here without touching its other uses in this module.
_ACCELERATOR_TYPE_TP_QUOTA_METRIC_MAP = {
    'NVIDIA_TESLA_P4': 'custom_model_serving_nvidia_p4_gpus',
    'NVIDIA_TESLA_T4': 'custom_model_serving_nvidia_t4_gpus',
    'NVIDIA_L4': 'custom_model_serving_nvidia_l4_gpus',
    'NVIDIA_TESLA_K80': 'custom_model_serving_nvidia_k80_gpus',
    'NVIDIA_TESLA_V100': 'custom_model_serving_nvidia_v100_gpus',
    'NVIDIA_TESLA_P100': 'custom_model_serving_nvidia_p100_gpus',
    'NVIDIA_TESLA_A100': 'custom_model_serving_nvidia_a100_gpus',
    'NVIDIA_A100_80GB': 'custom_model_serving_nvidia_a100_80gb_gpus',
    'NVIDIA_H100_80GB': 'custom_model_serving_nvidia_h100_gpus',
    'TPU_V5_LITEPOD': 'custom_model_serving_tpu_v5e',
}

# Monitoring time-series filter for quota usage; filled in with
# (quota metric suffix, project id, region) in that order.
_TIME_SERIES_FILTER = (
    # gcloud-disable-gdu-domain
    'metric.type="serviceruntime.googleapis.com/quota/allocation/usage" AND'
    ' resource.type="consumer_quota" AND'
    # gcloud-disable-gdu-domain
    ' metric.label.quota_metric="aiplatform.googleapis.com/{}"'
    ' AND resource.label.project_id="{}" AND resource.label.location="{}" AND'
    # gcloud-disable-gdu-domain
    ' resource.label.service="aiplatform.googleapis.com"'
)
def _ParseEndpoint(endpoint_id, location_id):
  """Parses a Vertex Endpoint ID into a endpoint resource object."""
  parse_params = {
      'projectsId': properties.VALUES.core.project.GetOrFail,
      'locationsId': location_id,
  }
  return resources.REGISTRY.Parse(
      endpoint_id,
      params=parse_params,
      collection='aiplatform.projects.locations.endpoints',
  )
def _GetQuotaLimit(region, project, accelerator_type):
  """Gets the quota limit for the accelerator type in the region."""
  accelerator_quota = quota_info.GetQuotaInfo(
      project,
      None,
      None,
      # gcloud-disable-gdu-domain
      'aiplatform.googleapis.com',
      _ACCELERATOR_TYPE_TO_QUOTA_ID_MAP[accelerator_type],
  )
  # Find the first dimension entry whose location matches the region.
  matching = (
      info for info in accelerator_quota.dimensionsInfos
      if info.applicableLocations[0] == region)
  region_info = next(matching, None)
  if region_info is None:
    return 0
  return region_info.details.value or 0
def _GetQuotaUsage(region, project, accelerator_type):
  """Gets the quota usage for the accelerator type in the region using the monitoring API."""
  # Format the time in RFC3339 UTC Zulu format
  current_time_utc = datetime.datetime.now(datetime.timezone.utc)
  # Need to go back at least 24 hours to reliably get a data point
  twenty_five_hours_ago_time_utc = current_time_utc - datetime.timedelta(
      hours=25
  )
  # Monitoring expects RFC3339 "Z"-suffixed timestamps, not "+00:00".
  rfc3339_time = current_time_utc.isoformat(timespec='seconds').replace(
      '+00:00', 'Z'
  )
  rfc3339_time_twenty_five_hours_ago = twenty_five_hours_ago_time_utc.isoformat(
      timespec='seconds'
  ).replace('+00:00', 'Z')
  # ALIGN_NEXT_OLDER with a 60s alignment period picks the most recent sample
  # at each boundary (per the Monitoring API aligner semantics).
  quota_usage_time_series = metric.MetricClient().ListTimeSeriesByProject(
      project=project,
      aggregation_alignment_period='60s',
      aggregation_per_series_aligner=metric.GetMessagesModule().MonitoringProjectsTimeSeriesListRequest.AggregationPerSeriesAlignerValueValuesEnum.ALIGN_NEXT_OLDER,
      interval_start_time=rfc3339_time_twenty_five_hours_ago,
      interval_end_time=rfc3339_time,
      filter_str=_TIME_SERIES_FILTER.format(
          _ACCELERATOR_TYPE_TP_QUOTA_METRIC_MAP[accelerator_type],
          project,
          region,
      ),
  )
  try:
    # The first point of the first series is the newest usage sample.
    current_usage = (
        quota_usage_time_series.timeSeries[0].points[0].value.int64Value
    )
  except IndexError:
    # If no data point is found, the usage is 0.
    current_usage = 0
  return current_usage
def GetCLIEndpointLabelValue(
    is_hf_model, publisher_name, model_name='', model_version_name=''
):
  """Builds the endpoint label value identifying the deployed model (CLI)."""
  if is_hf_model:
    raw_value = f'hf-{publisher_name}-{model_name}'
  else:
    raw_value = f'mg-{publisher_name}-{model_version_name}'
  # Normalize dots and cap the length for use as a label value.
  return raw_value.replace('.', '_')[:_MAX_LABEL_VALUE_LENGTH]
def GetOneClickEndpointLabelValue(
    is_hf_model, publisher_name, model_name='', model_version_name=''
):
  """Builds the label value used by one-click deployment for endpoints.

  Dots are replaced with underscores and the result is truncated to the
  maximum allowed label value length.
  """
  if is_hf_model:
    raw_value = f'hf-{publisher_name}-{model_name}'
  else:
    raw_value = (
        f'publishers-{publisher_name}-models-{model_name}-{model_version_name}'
    )
  return raw_value.replace('.', '_')[:_MAX_LABEL_VALUE_LENGTH]
def IsHFModelGated(publisher_name, model_name):
  """Checks if the HF model is gated or not by calling HF API."""
  metadata_url = (
      f'https://huggingface.co/api/models/{publisher_name}/{model_name}?blobs=true'
  )
  metadata_response = requests.GetSession().get(metadata_url)
  if metadata_response.status_code != 200:
    raise core_exceptions.InternalError(
        "Something went wrong when we call HuggingFace's API to get the"
        ' model metadata. Please try again later.'
    )
  # The 'gated' field may be False or a gating mode string; coerce to bool.
  return bool(metadata_response.json()['gated'])
def VerifyHFTokenPermission(hf_token, publisher_name, model_name):
  """Verifies the Hugging Face token can access the gated model."""
  auth_check_url = (
      f'https://huggingface.co/api/models/{publisher_name}/{model_name}/auth-check'
  )
  auth_response = requests.GetSession().request(
      'GET',
      auth_check_url,
      headers={'Authorization': f'Bearer {hf_token}'},
  )
  if auth_response.status_code != 200:
    raise core_exceptions.Error(
        'The Hugging Face access token is not valid or does not have permission'
        ' to access the gated model.'
    )
  return
def GetDeployConfig(args, publisher_model):
  """Returns a best suited deployment configuration for the publisher model.

  Args:
    args: The parsed command arguments; the optional machine_type,
      accelerator_type and container_image_uri attributes narrow down the
      candidate configurations.
    publisher_model: The publisher model whose multi-deploy configurations
      are inspected.

  Returns:
    The matching (or default, when no filter flags are set) deploy
    configuration message.

  Raises:
    core_exceptions.Error: If the model has no deploy configurations, or if
      none of them matches the requested machine/accelerator/container.
  """
  try:
    multi_deploy = (
        publisher_model.supportedActions.multiDeployVertex.multiDeployVertex
    )
  except AttributeError:
    multi_deploy = None
  # Treat a missing attribute and an empty config list the same way; the
  # previous version crashed with IndexError on an empty list below.
  if not multi_deploy:
    raise core_exceptions.Error(
        'Model does not support deployment, please use a deploy-able model'
        ' instead. You can use the `gcloud ai model-garden models list`'
        ' command to find out which ones are currently supported by the'
        ' `deploy` command.'
    )
  deploy_config = None
  if args.machine_type or args.accelerator_type or args.container_image_uri:
    # Pick the first config that satisfies every filter the user supplied.
    for deploy in multi_deploy:
      if (
          args.machine_type
          and deploy.dedicatedResources.machineSpec.machineType
          != args.machine_type
      ):
        continue
      if (
          args.accelerator_type
          and str(deploy.dedicatedResources.machineSpec.acceleratorType)
          != args.accelerator_type.upper()
      ):
        continue
      if (
          args.container_image_uri
          and deploy.containerSpec.imageUri != args.container_image_uri
      ):
        continue
      deploy_config = deploy
      break
    if not deploy_config:
      raise core_exceptions.Error(
          'The machine type, accelerator type and/or container image URI is not'
          ' supported by the model. You can use `gcloud alpha/beta ai'
          ' model-garden models list-deployment-config` command to find the'
          ' supported configurations'
      )
    log.status.Print('Using the selected deployment configuration:')
  else:
    # Default to use the first config.
    deploy_config = multi_deploy[0]
    log.status.Print('Using the default deployment configuration:')
  machine_spec = deploy_config.dedicatedResources.machineSpec
  container_image_uri = deploy_config.containerSpec.imageUri
  if machine_spec.machineType:
    log.status.Print(f' Machine type: {machine_spec.machineType}')
  if machine_spec.acceleratorType:
    log.status.Print(f' Accelerator type: {machine_spec.acceleratorType}')
  if machine_spec.acceleratorCount:
    log.status.Print(f' Accelerator count: {machine_spec.acceleratorCount}')
  if container_image_uri:
    log.status.Print(f' Container image URI: {container_image_uri}')
  return deploy_config
def CheckAcceleratorQuota(
    args, machine_type, accelerator_type, accelerator_count
):
  """Checks the accelerator quota for the project and region.

  Args:
    args: The parsed command arguments; region is read from it.
    machine_type: str, the machine type of the deployment.
    accelerator_type: str, the accelerator type; overridden for TPU machine
      types below.
    accelerator_count: int, the number of accelerators requested.

  Raises:
    core_exceptions.Error: If the quota limit is below the requested count,
      or if current usage plus the request would exceed the limit.
  """
  # In the machine spec, TPUs don't have accelerator type and count, but they
  # have machine type.
  if machine_type == 'ct5lp-hightpu-1t':
    accelerator_type = 'TPU_V5_LITEPOD'
    accelerator_count = 1
  elif machine_type == 'ct5lp-hightpu-4t':
    accelerator_type = 'TPU_V5_LITEPOD'
    accelerator_count = 4
  project = properties.VALUES.core.project.GetOrFail()
  quota_limit = _GetQuotaLimit(args.region, project, accelerator_type)
  # First check: the total limit must at least cover the request.
  if quota_limit < accelerator_count:
    raise core_exceptions.Error(
        'The project does not have enough quota for'
        f' {_ACCELERATOR_TYPE_TP_QUOTA_METRIC_MAP[accelerator_type]} in'
        f' {args.region} to'
        f' deploy the model. The quota limit is {quota_limit} and you are'
        f' requesting for {accelerator_count}. Please'
        ' use a different region or request more quota by following'
        ' https://cloud.google.com/vertex-ai/docs/quotas#requesting_additional_quota.'
    )
  # Second check: the remaining headroom (limit - usage) must cover it too.
  current_usage = _GetQuotaUsage(args.region, project, accelerator_type)
  if current_usage + accelerator_count > quota_limit:
    raise core_exceptions.Error(
        'The project does not have enough quota for'
        f' {_ACCELERATOR_TYPE_TP_QUOTA_METRIC_MAP[accelerator_type]} in'
        f' {args.region} to'
        f' deploy the model. The current usage is {current_usage} out of'
        f' {quota_limit} and you are'
        f' requesting for {accelerator_count}. Please'
        ' use a different region or request more quota by following'
        ' https://cloud.google.com/vertex-ai/docs/quotas#requesting_additional_quota.'
    )
  log.status.Print(
      'The project has enough quota. The current usage of quota for'
      f' accelerator type {accelerator_type} in region {args.region} is'
      f' {current_usage} out of {quota_limit}.'
  )
def CreateEndpoint(
    endpoint_name,
    label_value,
    region_ref,
    operation_client,
    endpoints_client,
):
  """Creates a Vertex endpoint for deployment."""
  messages = endpoints_client.messages
  labels_value_cls = messages.GoogleCloudAiplatformV1beta1Endpoint.LabelsValue
  # Tag the endpoint so CLI-created endpoints are identifiable later.
  endpoint_labels = labels_value_cls(
      additionalProperties=[
          labels_value_cls.AdditionalProperty(
              key='mg-cli-deploy', value=label_value
          )
      ]
  )
  create_endpoint_op = endpoints_client.CreateBeta(
      region_ref,
      endpoint_name,
      labels=endpoint_labels,
  )
  create_endpoint_response_msg = operations_util.WaitForOpMaybe(
      operation_client,
      create_endpoint_op,
      endpoints_util.ParseOperation(create_endpoint_op.name),
  )
  if create_endpoint_response_msg is None:
    raise core_exceptions.InternalError(
        'Internal error: Failed to create a Vertex endpoint. Please try again.'
    )
  response = encoding.MessageToPyValue(create_endpoint_response_msg)
  if 'name' not in response:
    raise core_exceptions.InternalError(
        'Internal error: Failed to create a Vertex endpoint. Please try again.'
    )
  log.status.Print(
      (
          'Created Vertex AI endpoint: {}.\nStarting to upload the model'
          ' to Model Registry.'
      ).format(response['name'])
  )
  # Return just the endpoint ID (last path segment of the resource name).
  return response['name'].split('/')[-1]
def UploadModel(
    deploy_config,
    args,
    requires_hf_token,
    is_hf_model,
    uploaded_model_name,
    publisher_name,
    publisher_model_name,
):
  """Uploads the Model Garden model to Model Registry.

  Args:
    deploy_config: The deploy configuration holding the container spec and
      artifact URI to upload.
    args: The parsed command arguments; region and hugging_face_access_token
      are read from it.
    requires_hf_token: bool, whether the (gated) model needs a Hugging Face
      token injected into the container environment.
    is_hf_model: bool, whether the model originates from Hugging Face; the
      publisher is prefixed with 'hf-' in the recorded base model source.
    uploaded_model_name: str, display name for the uploaded model.
    publisher_name: str, the model's publisher.
    publisher_model_name: str, the model name under the publisher.

  Returns:
    str, the model ID (last segment of the uploaded model resource name).

  Raises:
    core_exceptions.InternalError: If the upload operation returns no
      response, or the response carries no model name.
  """
  container_env_vars, container_args, container_commands = None, None, None
  if deploy_config.containerSpec.env:
    container_env_vars = {
        var.name: var.value for var in deploy_config.containerSpec.env
    }
    # Overwrite the placeholder token from the deploy config with the
    # user-supplied Hugging Face access token.
    if requires_hf_token and 'HUGGING_FACE_HUB_TOKEN' in container_env_vars:
      container_env_vars['HUGGING_FACE_HUB_TOKEN'] = (
          args.hugging_face_access_token
      )
  if deploy_config.containerSpec.args:
    container_args = list(deploy_config.containerSpec.args)
  if deploy_config.containerSpec.command:
    container_commands = list(deploy_config.containerSpec.command)
  models_client = client_models.ModelsClient()
  # NOTE: positional argument order here must match ModelsClient.UploadV1Beta1;
  # only the first container port is forwarded.
  upload_model_op = models_client.UploadV1Beta1(
      args.CONCEPTS.region.Parse(),
      uploaded_model_name,  # Re-use endpoint_name as the uploaded model name.
      None,
      None,
      deploy_config.artifactUri,
      deploy_config.containerSpec.imageUri,
      container_commands,
      container_args,
      container_env_vars,
      [deploy_config.containerSpec.ports[0].containerPort],
      None,
      deploy_config.containerSpec.predictRoute,
      deploy_config.containerSpec.healthRoute,
      base_model_source=models_client.messages.GoogleCloudAiplatformV1beta1ModelBaseModelSource(
          modelGardenSource=models_client.messages.GoogleCloudAiplatformV1beta1ModelGardenSource(
              # The value is consistent with one-click deploy.
              publicModelName='publishers/{}/models/{}'.format(
                  'hf-' + publisher_name if is_hf_model else publisher_name,
                  publisher_model_name,
              )
          )
      ),
  )
  upload_model_response_msg = operations_util.WaitForOpMaybe(
      operations_client=operations.OperationsClient(),
      op=upload_model_op,
      op_ref=models_util.ParseModelOperation(upload_model_op.name),
  )
  if upload_model_response_msg is None:
    raise core_exceptions.InternalError(
        'Internal error: Failed to upload a Model Garden model to Model'
        ' Registry. Please try again later.'
    )
  upload_model_response = encoding.MessageToPyValue(upload_model_response_msg)
  if 'model' not in upload_model_response:
    raise core_exceptions.InternalError(
        'Internal error: Failed to upload a Model Garden model to Model'
        ' Registry. Please try again later.'
    )
  log.status.Print(
      (
          'Uploaded model to Model Registry at {}.\nStarting to deploy the'
          ' model.'
      ).format(upload_model_response['model'])
  )
  # Return just the model ID (last path segment of the resource name).
  return upload_model_response['model'].split('/')[-1]
def DeployModel(
    args,
    deploy_config,
    endpoint_id,
    endpoint_name,
    model_id,
    endpoints_client,
    operation_client,
):
  """Deploys the Model Registry model to the Vertex endpoint.

  Args:
    args: The parsed command arguments; region is read from it.
    deploy_config: The deploy configuration holding the machine spec.
    endpoint_id: str, the ID of the target Vertex endpoint.
    endpoint_name: str, display name; re-used as the deployed model name.
    model_id: str, the Model Registry model ID to deploy.
    endpoints_client: The Vertex endpoints API client.
    operation_client: The operations client used for the deploy LRO.
  """
  machine_spec = deploy_config.dedicatedResources.machineSpec
  accelerator_type = machine_spec.acceleratorType
  accelerator_count = machine_spec.acceleratorCount
  accelerator_dict = None
  if accelerator_type is not None or accelerator_count is not None:
    accelerator_dict = {}
    if accelerator_type is not None:
      # The API enum form (e.g. NVIDIA_TESLA_T4) is converted to the flag
      # form (nvidia-tesla-t4) expected by the endpoints client.
      accelerator_dict['type'] = str(accelerator_type).lower().replace('_', '-')
    if accelerator_count is not None:
      accelerator_dict['count'] = accelerator_count
  deploy_model_op = endpoints_client.DeployModelBeta(
      _ParseEndpoint(endpoint_id, args.region),
      model_id,
      args.region,
      endpoint_name,  # Use the endpoint_name as the deployed model name.
      machine_type=machine_spec.machineType,
      accelerator_dict=accelerator_dict,
      enable_access_logging=True,
      enable_container_logging=True,
  )
  operations_util.WaitForOpMaybe(
      operation_client,
      deploy_model_op,
      endpoints_util.ParseOperation(deploy_model_op.name),
      asynchronous=True,  # Deploy the model asynchronously.
  )
  deploy_op_id = deploy_model_op.name.split('/')[-1]
  # Use log.status.Print for user-facing output, consistent with the rest of
  # this module (previously this was a bare print()).
  log.status.Print(
      'Deploying the model to the endpoint. To check the deployment'
      ' status, you can try one of the following methods:\n1) Look for'
      f' endpoint `{endpoint_name}` at the [Vertex AI] -> [Online'
      ' prediction] tab in Cloud Console\n2) Use `gcloud ai operations'
      f' describe {deploy_op_id} --region={args.region}` to find the status'
      ' of the deployment long-running operation\n3) Use `gcloud ai'
      f' endpoints describe {endpoint_id} --region={args.region}` command'
      " to check the endpoint's metadata."
  )
def Deploy(
    args, machine_spec, endpoint_name, model, operation_client, mg_client
):
  """Deploys the publisher model to a Vertex endpoint.

  Retries itself once per recoverable API error: it re-prompts for a Hugging
  Face token when the API reports a missing/invalid token, and re-prompts for
  EULA acceptance when the API reports the EULA has not been accepted.

  Args:
    args: The parsed command arguments; mutated in place when the user is
      prompted for a Hugging Face token or EULA acceptance.
    machine_spec: The machine spec message to deploy with, or None to let
      the service choose a configuration.
    endpoint_name: str, the display name for the created endpoint.
    model: The Model Garden model to deploy.
    operation_client: The operations client used to poll the deploy LRO.
    mg_client: The Model Garden API client.

  Raises:
    core_exceptions.Error: If the EULA prompt is declined unattended.
    apitools_exceptions.HttpError: For unrecoverable API errors.
  """
  try:
    if machine_spec is not None:
      machine_type = machine_spec.machineType
      accelerator_type = machine_spec.acceleratorType
      accelerator_count = machine_spec.acceleratorCount
    else:
      machine_type = None
      accelerator_type = None
      accelerator_count = None
    deploy_op = mg_client.Deploy(
        project=properties.VALUES.core.project.GetOrFail(),
        location=args.region,
        model=model,
        accept_eula=args.accept_eula,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
        machine_type=machine_type,
        endpoint_display_name=endpoint_name,
        hugging_face_access_token=args.hugging_face_access_token,
        spot=args.spot,
        reservation_affinity=args.reservation_affinity,
        use_dedicated_endpoint=args.use_dedicated_endpoint,
        enable_fast_tryout=args.enable_fast_tryout,
        container_image_uri=args.container_image_uri,
        container_command=args.container_command,
        container_args=args.container_args,
        container_env_vars=args.container_env_vars,
        container_ports=args.container_ports,
        container_grpc_ports=args.container_grpc_ports,
        container_predict_route=args.container_predict_route,
        container_health_route=args.container_health_route,
        container_deployment_timeout_seconds=args.container_deployment_timeout_seconds,
        container_shared_memory_size_mb=args.container_shared_memory_size_mb,
        container_startup_probe_exec=args.container_startup_probe_exec,
        container_startup_probe_period_seconds=args.container_startup_probe_period_seconds,
        container_startup_probe_timeout_seconds=args.container_startup_probe_timeout_seconds,
        container_health_probe_exec=args.container_health_probe_exec,
        container_health_probe_period_seconds=args.container_health_probe_period_seconds,
        container_health_probe_timeout_seconds=args.container_health_probe_timeout_seconds,
    )
  except apitools_exceptions.HttpError as e:
    # Keep prompting for HF token if the error is due to missing HF token.
    if (
        e.status_code == 400
        and 'provide a valid Hugging Face access token' in e.content
        and args.hugging_face_access_token is None
    ):
      while not args.hugging_face_access_token:
        args.hugging_face_access_token = console_io.PromptPassword(
            'Please enter your Hugging Face read access token: '
        )
      # Retry the whole deployment with the freshly collected token.
      Deploy(
          args,
          machine_spec,
          endpoint_name,
          model,
          operation_client,
          mg_client,
      )
      return
    elif e.status_code == 403 and 'EULA' in e.content:
      log.status.Print(
          'The End User License Agreement'
          ' (EULA) of the model has not been accepted.'
      )
      # NOTE(review): assumes args.model is 'publisher/model[@version]' --
      # confirm against the command's model argument format.
      publisher, model_id = args.model.split('@')[0].split('/')
      try:
        args.accept_eula = console_io.PromptContinue(
            message=(
                'The model can be deployed only if the EULA of the model has'
                ' been'
                ' accepted. You can view it at'
                f' https://console.cloud.google.com/vertex-ai/publishers/{publisher}/model-garden/{model_id}):'
            ),
            prompt_string='Do you want to accept the EULA?',
            default=False,
            cancel_on_no=True,
            cancel_string='EULA is not accepted.',
            throw_if_unattended=True,
        )
      except console_io.Error:
        raise core_exceptions.Error(
            'Please accept the EULA using the `--accept-eula` flag.'
        )
      # Retry the whole deployment with the EULA now accepted.
      Deploy(
          args,
          machine_spec,
          endpoint_name,
          model,
          operation_client,
          mg_client,
      )
      return
    else:
      raise e
  deploy_op_id = deploy_op.name.split('/')[-1]
  log.status.Print(
      'Deploying the model to the endpoint. To check the deployment'
      ' status, you can try one of the following methods:\n1) Look for'
      f' endpoint `{endpoint_name}` at the [Vertex AI] -> [Online'
      ' prediction] tab in Cloud Console\n2) Use `gcloud ai operations'
      f' describe {deploy_op_id} --region={args.region}` to find the status'
      ' of the deployment long-running operation\n'
  )
  operations_util.WaitForOpMaybe(
      operation_client,
      deploy_op,
      ParseOperation(deploy_op.name),
      asynchronous=args.asynchronous,
      max_wait_ms=3600000,  # 60 minutes
  )
def ParseOperation(operation_name):
  """Parse operation resource to the operation reference object.

  Args:
    operation_name: The operation resource to wait on

  Returns:
    The operation reference object
  """
  operations_collection = 'aiplatform.projects.locations.operations'
  return resources.REGISTRY.ParseRelativeName(
      operation_name, collection=operations_collection)

View File

@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform model deployment monitoring jobs commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import io
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml
from googlecloudsdk.core.console import console_io
def ParseJobName(name):
  """Returns the resource ID parsed from a model monitoring job name."""
  job_ref = resources.REGISTRY.Parse(
      name, collection=constants.MODEL_MONITORING_JOBS_COLLECTION)
  return job_ref.Name()
def ReadInstanceFromArgs(path):
  """Reads the instance from the given file path ('-' for stdin).

  Args:
    path: str or None, a path to a file ('-' for stdin) containing the JSON
      body.

  Returns:
    A instance.

  Raises:
    InvalidInstancesFileError: If the input file is invalid (invalid format or
      contains too many/zero instances), or an improper combination of input
      files was given.
  """
  raw_bytes = console_io.ReadFromFileOrStdin(path, binary=True)
  # Both a parse failure and a non-mapping document are reported identically.
  bad_format_error = errors.InvalidInstancesFileError(
      'Input instance are not in JSON format. '
      'See `gcloud ai model-monitoring-jobs create --help` for details.')
  with io.BytesIO(raw_bytes) as stream:
    try:
      parsed_instance = yaml.load(stream)
    except ValueError:
      raise bad_format_error
  if not isinstance(parsed_instance, dict):
    raise bad_format_error
  return parsed_instance
def ParseMonitoringJobOperation(operation_name):
  """Parse operation relative resource name to the operation reference object.

  Args:
    operation_name: The operation resource name

  Returns:
    The operation reference object
  """
  # Job-scoped operations live under the monitoring-job collection; anything
  # else falls through to the generic location-scoped operations collection.
  job_scoped_collection = (
      'aiplatform.projects.locations.modelDeploymentMonitoringJobs.operations'
  )
  if '/modelDeploymentMonitoringJobs/' in operation_name:
    try:
      return resources.REGISTRY.ParseRelativeName(
          operation_name, collection=job_scoped_collection)
    except resources.WrongResourceCollectionException:
      pass
  return resources.REGISTRY.ParseRelativeName(
      operation_name, collection='aiplatform.projects.locations.operations')

View File

@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform models commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.core import resources
def ParseModelOperation(operation_name):
  """Parse operation relative resource name to the operation reference object.

  Args:
    operation_name: The operation resource name

  Returns:
    The operation reference object
  """
  # Model-scoped operations live under the models collection; anything else
  # falls through to the generic location-scoped operations collection.
  model_scoped_collection = 'aiplatform.projects.locations.models.operations'
  if '/models/' in operation_name:
    try:
      return resources.REGISTRY.ParseRelativeName(
          operation_name, collection=model_scoped_collection)
    except resources.WrongResourceCollectionException:
      pass
  return resources.REGISTRY.ParseRelativeName(
      operation_name, collection='aiplatform.projects.locations.operations')

View File

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform operations commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.core import log
def WaitForOpMaybe(
    operations_client,
    op,
    op_ref,
    asynchronous=False,
    log_method=None,
    message=None,
    kind=None,
    max_wait_ms=1800000,
):
  """Waits for an operation if asynchronous flag is off.

  Args:
    operations_client: api_lib.ai.operations.OperationsClient, the client via
      which to poll.
    op: Cloud AI Platform operation, the operation to poll.
    op_ref: The operation reference to the operation resource. It's the result
      by calling resources.REGISTRY.Parse
    asynchronous: bool, whether to wait for the operation or return immediately
    log_method: Logging method used for synchronous operation. If None, no log
    message: str, the message to display while waiting for the operation.
    kind: str, the resource kind (instance, cluster, project, etc.), which will
      be passed to logging function.
    max_wait_ms: int, number of ms to wait before raising WaitException.

  Returns:
    The result of the operation if asynchronous is true, or the Operation
    message otherwise
  """
  if asynchronous:
    # Map the CRUD verb to its resource logger; unknown/absent verbs are
    # simply not logged.
    verb_loggers = {
        'create': log.CreatedResource,
        'delete': log.DeletedResource,
        'update': log.UpdatedResource,
    }
    logger = verb_loggers.get(log_method)
    if logger is not None:
      logger(op.name, kind=kind)
    return op
  completed_op = operations_client.WaitForOperation(
      op, op_ref, message=message, max_wait_ms=max_wait_ms
  )
  return completed_op.response

View File

@@ -0,0 +1,178 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flags definition for gcloud ai persistent-resources."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import textwrap
from googlecloudsdk.calliope import arg_parsers
from googlecloudsdk.calliope import base
from googlecloudsdk.calliope.concepts import concepts
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import flags as shared_flags
from googlecloudsdk.command_lib.ai import region_util
from googlecloudsdk.command_lib.ai.persistent_resources import persistent_resource_util
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.command_lib.util.concepts import concept_parsers
# TODO(b/262780738): Add link to persistent resource spec once spec is
# published to the public.
# Flag definitions shared by the persistent-resources command group.
_PERSISTENT_RESOURCE_CONFIG = base.Argument(
    '--config',
    help=textwrap.dedent("""\
      Path to the Persistent Resource configuration file. This file should be a
      YAML document containing a list of `ResourcePool`
      If an option is specified both in the configuration file **and** via
      command-line arguments, the command-line arguments override the
      configuration file. Note that keys with underscore are invalid.

      Example(YAML):

        resourcePoolSpecs:
          machineSpec:
            machineType: n1-standard-4
          replicaCount: 1"""))

_RESOURCE_POOL_SPEC = base.Argument(
    '--resource-pool-spec',
    action='append',
    type=arg_parsers.ArgDict(
        spec={
            'replica-count': int,
            'min-replica-count': int,
            'max-replica-count': int,
            'machine-type': str,
            'accelerator-type': str,
            'accelerator-count': int,
            'disk-type': str,
            'disk-size': int,
            'local-ssd-count': int,
        }),
    metavar='RESOURCE_POOL_SPEC',
    help=textwrap.dedent("""\
      Defines a resource pool to be created in the Persistent Resource. You can
      include multiple resource pool specs in order to create a Persistent
      Resource with multiple resource pools.

      The spec can contain the following fields:

      *machine-type*::: (Required): The type of the machine.
      see https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types
      for supported types. This field corresponds to the `machineSpec.machineType`
      field in `ResourcePool` API message.
      *replica-count*::: (Required if autoscaling not enabled) The number of
      replicas to use when creating this resource pool. This field
      corresponds to the replicaCount field in 'ResourcePool' API message.
      *min-replica-count*::: (Optional) The minimum number of replicas that
      autoscaling will down-size to for this resource pool. Both
      min-replica-count and max-replica-count are required to enable
      autoscaling on this resource pool. The value for this parameter must be
      at least 1.
      *max-replica-count*::: (Optional) The maximum number of replicas that
      autoscaling will create for this resource pool. Both min-replica-count
      and max-replica-count are required to enable autoscaling on this
      resource pool. The maximum value for this parameter is 1000.
      *accelerator-type*::: (Optional) The type of GPU to attach to the
      machines.
      see https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus
      for more requirements. This field corresponds to the `machineSpec.acceleratorType`
      field in `ResourcePool` API message.
      *accelerator-count*::: (Required with accelerator-type) The number of GPUs
      for each VM in the resource pool to use. The default value is 1.
      This field corresponds to the `machineSpec.acceleratorCount` field in
      `ResourcePool` API message.
      *disk-type*::: (Optional) The type of disk to use for each machine's boot disk in
      the resource pool. The default is `pd-standard`. This field corresponds
      to the `diskSpec.bootDiskType` field in `ResourcePool` API message.
      *disk-size*::: (Optional) The disk size in Gb for each machine's boot disk in the
      resource pool. The default is `100`. This field corresponds to
      the `diskSpec.bootDiskSizeGb` field in `ResourcePool` API message.
      ::::
      Example:
      --resource-pool-spec=replica-count=1,machine-type=n1-highmem-2
      """))

ENABLE_CUSTOM_SERVICE_ACCOUNT = base.Argument(
    '--enable-custom-service-account',
    action='store_true',
    required=False,
    help=textwrap.dedent("""\
      Whether or not to use a custom user-managed service account with this
      Persistent Resource.
      """))
def AddCreatePersistentResourceFlags(parser):
  """Adds flags related to create a Persistent Resource."""
  shared_flags.AddRegionResourceArg(
      parser,
      'to create a Persistent Resource',
      prompt_func=region_util.GetPromptForRegionFunc(
          constants.SUPPORTED_TRAINING_REGIONS))
  shared_flags.NETWORK.AddToParser(parser)
  ENABLE_CUSTOM_SERVICE_ACCOUNT.AddToParser(parser)
  # TODO(b/262780738): Unimplemented
  # shared_flags.TRAINING_SERVICE_ACCOUNT.AddToParser(parser)
  shared_flags.AddKmsKeyResourceArg(parser, 'persistent resource')
  labels_util.AddCreateLabelsFlags(parser)
  display_name_arg = shared_flags.GetDisplayNameArg(
      'Persistent Resource', required=False)
  display_name_arg.AddToParser(parser)
  base.Argument(
      '--persistent-resource-id',
      required=True,
      default=None,
      help='User-specified ID of the Persistent Resource.').AddToParser(parser)
  # Resource pools can come from a YAML config, repeated spec flags, or both;
  # the group is required so at least one source is present.
  pool_spec_group = base.ArgumentGroup(
      help='resource pool specification.', required=True
  )
  pool_spec_group.AddArgument(_PERSISTENT_RESOURCE_CONFIG)
  pool_spec_group.AddArgument(_RESOURCE_POOL_SPEC)
  pool_spec_group.AddToParser(parser)
def AddPersistentResourceResourceArg(
    parser, verb, regions=constants.SUPPORTED_TRAINING_REGIONS):
  """Add a resource argument for a Vertex AI Persistent Resource.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
    regions: list[str], the list of supported regions.
  """
  spec = concepts.ResourceSpec(
      resource_collection=persistent_resource_util.PERSISTENT_RESOURCE_COLLECTION,
      resource_name='persistent resource',
      locationsId=shared_flags.RegionAttributeConfig(
          prompt_func=region_util.GetPromptForRegionFunc(regions)),
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      disable_auto_completers=False)
  presentation = concept_parsers.ConceptParser.ForResource(
      'persistent_resource',
      spec,
      'The persistent resource {}.'.format(verb),
      required=True)
  presentation.AddToParser(parser)

View File

@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform persistent resource commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.command_lib.util.apis import arg_utils
PERSISTENT_RESOURCE_COLLECTION = 'aiplatform.projects.locations.persistentResources'
def _ConstructSingleResourcePoolSpec(aiplatform_client,
                                     spec):
  """Constructs a single resource pool spec.

  Args:
    aiplatform_client: The AI Platform API client used.
    spec: A dict whose fields represent a resource pool spec.

  Returns:
    A ResourcePoolSpec message instance for setting a resource pool in a
    Persistent Resource
  """
  resource_pool = aiplatform_client.GetMessage('ResourcePool')()
  machine_spec_msg = aiplatform_client.GetMessage('MachineSpec')
  machine_spec = machine_spec_msg(machineType=spec.get('machine-type'))
  accelerator_type = spec.get('accelerator-type')
  if accelerator_type:
    machine_spec.acceleratorType = arg_utils.ChoiceToEnum(
        accelerator_type, machine_spec_msg.AcceleratorTypeValueValuesEnum)
    # accelerator-count defaults to 1 when an accelerator type is given.
    machine_spec.acceleratorCount = int(spec.get('accelerator-count', 1))
  resource_pool.machineSpec = machine_spec
  replica_count = spec.get('replica-count')
  if replica_count:
    resource_pool.replicaCount = int(replica_count)
  min_replica_count = spec.get('min-replica-count')
  max_replica_count = spec.get('max-replica-count')
  if min_replica_count or max_replica_count:
    # NOTE(review): if only one of min/max is supplied, int(None) raises
    # TypeError here -- presumably the create-args validation enforces that
    # both are set together; confirm.
    autoscaling_spec = (
        aiplatform_client.GetMessage('ResourcePoolAutoscalingSpec')())
    autoscaling_spec.minReplicaCount = int(min_replica_count)
    autoscaling_spec.maxReplicaCount = int(max_replica_count)
    resource_pool.autoscalingSpec = autoscaling_spec
  disk_type = spec.get('disk-type')
  disk_size = spec.get('disk-size')
  if disk_type:
    # NOTE(review): disk-size is silently ignored when disk-type is absent.
    disk_spec_msg = aiplatform_client.GetMessage('DiskSpec')
    disk_spec = disk_spec_msg(bootDiskType=disk_type, bootDiskSizeGb=disk_size)
    resource_pool.diskSpec = disk_spec
  return resource_pool
def _ConstructResourcePoolSpecs(aiplatform_client, specs, **kwargs):
"""Constructs the resource pool specs for a persistent resource.
Args:
aiplatform_client: The AI Platform API client used.
specs: A list of dict of resource pool specs, supposedly derived from
the gcloud command flags.
**kwargs: The keyword args to pass down to construct each worker pool spec.
Returns:
A list of ResourcePool message instances for creating a Persistent Resource.
"""
resource_pool_specs = []
for spec in specs:
if spec:
resource_pool_specs.append(
_ConstructSingleResourcePoolSpec(aiplatform_client, spec, **kwargs))
else:
resource_pool_specs.append(
aiplatform_client.GetMessage('ResourcePoolSpec')())
return resource_pool_specs
def ConstructResourcePools(
    aiplatform_client,
    persistent_resource_config=None,
    resource_pool_specs=None,
    **kwargs
):
  """Constructs the resource pools to be used to create a Persistent Resource.

  Resource pools from the config file and arguments will be combined.

  Args:
    aiplatform_client: The AI Platform API client used.
    persistent_resource_config: A Persistent Resource configuration imported
      from a YAML config, or None when no config file was supplied.
    resource_pool_specs: A dict of worker pool specification, usually derived
      from the gcloud command argument values.
    **kwargs: The keyword args to pass to construct the worker pool specs.

  Returns:
    An array of ResourcePool messages for creating a Persistent Resource.
  """
  resource_pools = []
  # Guard against a missing config: the parameter defaults to None, and the
  # previous version raised AttributeError when no YAML config was given.
  if persistent_resource_config is not None and isinstance(
      persistent_resource_config.resourcePools, list):
    resource_pools = persistent_resource_config.resourcePools
  if resource_pool_specs:
    resource_pools = resource_pools + _ConstructResourcePoolSpecs(
        aiplatform_client, resource_pool_specs, **kwargs)
  return resource_pools
def _IsKwargsDefined(key, **kwargs):
return key in kwargs and bool(kwargs.get(key))

View File

@@ -0,0 +1,184 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Validation of the arguments for the persistent-resources command group."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ai import util as api_util
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import validation
def ValidateRegion(region):
  """Validate whether the given region is allowed for persistent resources."""
  training_regions = constants.SUPPORTED_TRAINING_REGIONS
  validation.ValidateRegion(region, available_regions=training_regions)
def ValidateCreateArgs(args, persistent_resource_config, version):
  """Validate the argument values specified in the `create` command."""
  config_has_pools = isinstance(persistent_resource_config.resourcePools, list)
  if args.resource_pool_spec:
    _ValidateResourcePoolSpecArgs(args.resource_pool_spec, version)
  if config_has_pools:
    _ValidateResourcePoolSpecsFromConfig(
        persistent_resource_config.resourcePools, version)
  # At least one source of resource pools (flag or config) must be present.
  if not args.resource_pool_spec and not config_has_pools:
    raise exceptions.InvalidArgumentException(
        '--resource-pool-spec',
        'No resource pools specified. At least one resource pool must be '
        'provided via a YAML config file (--config) or via the '
        '--resource-pool-spec arg.')
def _ValidateResourcePoolSpecArgs(resource_pool_specs, version):
  """Validates the argument values specified via `--resource-pool-spec` flags.

  Args:
    resource_pool_specs: List[dict], a list of resource pool specs specified
      via command line arguments.
    version: str, the API version this command will interact with, either GA
      or BETA.
  """
  first_spec = resource_pool_specs[0]
  if not first_spec:
    raise exceptions.InvalidArgumentException(
        '--resource-pool-spec',
        'Empty value is not allowed for the first `--resource-pool-spec` flag.')
  _ValidateHardwareInResourcePoolSpecArgs(resource_pool_specs, version)
def _ValidateHardwareInResourcePoolSpecArgs(resource_pool_specs, version):
"""Validates the hardware related fields specified in `--resource-pool-spec` flags.
Args:
resource_pool_specs: List[dict], a list of resource pool specs specified via
command line arguments.
version: str, the API version this command will interact with, either GA or
BETA.
"""
for spec in resource_pool_specs:
if spec:
if 'machine-type' not in spec:
raise exceptions.InvalidArgumentException(
'--resource-pool-spec',
'Key [machine-type] required in dict arg but not provided.')
if ('min-replica-count' in spec) and ('max-replica-count' not in spec):
raise exceptions.InvalidArgumentException(
'--resource-pool-spec',
'Key [max-replica-count] required in dict arg when key '
'[min-replica-count] is provided.')
if ('max-replica-count' in spec) and ('min-replica-count' not in spec):
raise exceptions.InvalidArgumentException(
'--resource-pool-spec',
'Key [min-replica-count] required in dict arg when key '
'[max-replica-count] is provided.')
# Require replica count if autoscaling is not enabled on the resource pool
if ('replica-count' not in spec) and ('min-replica-count' not in spec):
raise exceptions.InvalidArgumentException(
'--resource-pool-spec',
'Key [replica-count] required in dict arg but not provided.')
if ('accelerator-count' in spec) != ('accelerator-type' in spec):
raise exceptions.InvalidArgumentException(
'--resource-pool-spec',
'Key [accelerator-type] and [accelerator-count] are required to ' +
'use accelerators.')
accelerator_type = spec.get('accelerator-type', None)
if accelerator_type:
type_enum = api_util.GetMessage(
'MachineSpec', version).AcceleratorTypeValueValuesEnum
valid_types = [
type for type in type_enum.names()
if type.startswith('NVIDIA')
]
if accelerator_type not in valid_types:
raise exceptions.InvalidArgumentException(
'--resource-pool-spec',
('Found invalid value of [accelerator-type]: {actual}. '
'Available values are [{expected}].').format(
actual=accelerator_type,
expected=', '.join(v for v in sorted(valid_types))))
def _ValidateResourcePoolSpecsFromConfig(resource_pools, version):
"""Validate ResourcePoolSpec message instances imported from the config file."""
if not resource_pools:
raise exceptions.InvalidArgumentException(
'--config',
'At least one [resourcePools] required in but not provided in config.')
for spec in resource_pools:
if not spec.machineSpec:
raise exceptions.InvalidArgumentException(
'--config',
'Field [machineSpec] required in but not provided in config.')
if not spec.machineSpec.machineType:
raise exceptions.InvalidArgumentException(
'--config',
'Field [machineType] required in but not provided in config.')
if (not spec.replicaCount) and (not spec.autoscalingSpec):
raise exceptions.InvalidArgumentException(
'--config',
'Field [replicaCount] required in but not provided in config.')
if (spec.autoscalingSpec) and (not spec.autoscalingSpec.minReplicaCount):
raise exceptions.InvalidArgumentException(
'--config',
'Field [minReplicaCount] required when using autoscaling')
if (spec.autoscalingSpec) and (not spec.autoscalingSpec.maxReplicaCount):
raise exceptions.InvalidArgumentException(
'--config',
'Field [maxReplicaCount] required when using autoscaling')
if (spec.machineSpec.acceleratorCount and
not spec.machineSpec.acceleratorType):
raise exceptions.InvalidArgumentException(
'--config',
'Field [acceleratorType] required as [acceleratorCount] is specified'
'in config.')
if spec.diskSpec and (spec.diskSpec.bootDiskSizeGb and
not spec.diskSpec.bootDiskType):
raise exceptions.InvalidArgumentException(
'--config',
'Field [bootDiskType] required as [bootDiskSizeGb] is specified'
'in config.')
if spec.machineSpec.acceleratorType:
accelerator_type = str(spec.machineSpec.acceleratorType.name)
type_enum = api_util.GetMessage(
'MachineSpec', version).AcceleratorTypeValueValuesEnum
valid_types = [
type for type in type_enum.names()
if type.startswith('NVIDIA')
]
if accelerator_type not in valid_types:
raise exceptions.InvalidArgumentException(
'--config',
('Found invalid value of [acceleratorType]: {actual}. '
'Available values are [{expected}].').format(
actual=accelerator_type,
expected=', '.join(v for v in sorted(valid_types))))

View File

@@ -0,0 +1,157 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for handling region flag."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.console import console_io
def _IsDefaultUniverse():
  """Returns True when the active universe domain is the default (GDU)."""
  universe_domain = properties.VALUES.core.universe_domain
  return universe_domain.Get() == universe_domain.default
def _HandleNonDefaultUniverseRegion():
  """Handles region selection for non-default universes.

  Returns:
    The single configured region when exactly one is available, the region the
    user picks when prompting is possible, or None when no regions are
    configured or the console cannot prompt.
  """
  non_default_universe_regions = constants.NON_DEFAULT_UNIVERSE_REGIONS
  if not non_default_universe_regions:
    return None
  if len(non_default_universe_regions) == 1:
    return non_default_universe_regions[0]
  # Prompt only if console is available and there are multiple non-default
  # regions
  if console_io.CanPrompt():
    all_regions = list(non_default_universe_regions)
    idx = console_io.PromptChoice(
        all_regions, message='Please specify a region:\n', cancel_option=True)
    return all_regions[idx]
  # When the console cannot prompt there is no safe choice among multiple
  # regions, so return None and let the caller raise a required-flag error.
  return None
def GetPromptForRegionFunc(available_regions=constants.SUPPORTED_REGION):
  """Returns a no argument function that prompts available regions and catches the user selection."""
  def _Prompt():
    return PromptForRegion(available_regions)
  return _Prompt
def PromptForRegion(available_regions=constants.SUPPORTED_REGION):
  """Prompt for region from list of available regions.

  This method is referenced by the declarative iam commands as a fallthrough
  for getting the region.

  Args:
    available_regions: list of the available regions to choose from

  Returns:
    The region specified by the user, str, or None if not in GDU or cannot
    prompt.
  """
  if not _IsDefaultUniverse():
    return _HandleNonDefaultUniverseRegion()
  if not console_io.CanPrompt():
    return None
  choices = list(available_regions)
  selection = console_io.PromptChoice(
      choices, message='Please specify a region:\n', cancel_option=True)
  region = choices[selection]
  log.status.Print('To make this the default region, run '
                   '`gcloud config set ai/region {}`.\n'.format(region))
  return region
def PromptForOpRegion():
  """Prompt for region from list of online prediction available regions.

  This method is referenced by the declarative iam commands as a fallthrough
  for getting the region.

  Returns:
    The region specified by the user, str, or None if not in GDU or cannot
    prompt.
  """
  # Delegate to PromptForRegion: the body was a verbatim copy of it with a
  # different region list.
  return PromptForRegion(available_regions=constants.SUPPORTED_OP_REGIONS)
def PromptForDeploymentResourcePoolSupportedRegion():
  """Prompt for region from list of deployment resource pool available regions.

  This method is referenced by the declarative iam commands as a fallthrough
  for getting the region.

  Returns:
    The region specified by the user, str, or None if not in GDU or cannot
    prompt.
  """
  # Delegate to PromptForRegion: the body was a verbatim copy of it with a
  # different region list.
  return PromptForRegion(
      available_regions=constants.SUPPORTED_DEPLOYMENT_RESOURCE_POOL_REGIONS)
def GetRegion(args, prompt_func=PromptForRegion):
  """Gets the region and prompt for region if not provided.

  Region is decided in the following order:
  - region argument;
  - ai/region gcloud config;
  - prompt user input (only in GDU).

  Args:
    args: Namespace, The args namespace.
    prompt_func: Function, To prompt for region from list of available regions.

  Returns:
    A str representing region.

  Raises:
    RequiredArgumentException: if no region could be determined.
  """
  flag_region = getattr(args, 'region', None)
  if flag_region:
    return flag_region
  if properties.VALUES.ai.region.IsExplicitlySet():
    return properties.VALUES.ai.region.Get()
  prompted_region = prompt_func()
  if prompted_region:
    return prompted_region
  # In unit test, it's not allowed to prompt for asking the choices. Raising
  # the error immediately.
  raise exceptions.RequiredArgumentException('--region', 'Region is required')

View File

@@ -0,0 +1,18 @@
operation:
name: operation
collection: aiplatform.projects.locations.operations
attributes:
- *location
- parameter_name: operationsId
attribute_name: operation
help: The name of Vertex AI operation.
model:
name: model
collection: aiplatform.projects.locations.models
attributes:
- *location
- &model
parameter_name: modelsId
attribute_name: model
help: The name of Vertex AI model.

View File

@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flags definition specifically for gcloud ai ray job."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import textwrap
from googlecloudsdk.calliope import arg_parsers
from googlecloudsdk.calliope import base
from googlecloudsdk.calliope.concepts import concepts
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import flags as shared_flags
from googlecloudsdk.command_lib.ai import region_util
from googlecloudsdk.command_lib.ai.serverless_ray_jobs import serverless_ray_jobs_util
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.command_lib.util.concepts import concept_parsers
_ENTRYPOINT_FILE_URI = base.Argument(
'--entrypoint',
metavar='ENTRYPOINT_FILE_URI',
required=True,
help='The Ray job entrypoint Python file Google Cloud Storage URI.',
)
_ENTRYPOINT_JOB_FILE_ARGS = base.Argument(
'--entrypoint-file-args',
metavar='ARG',
type=arg_parsers.ArgList(),
action=arg_parsers.UpdateAction,
help=(
'Comma-separated arguments passed to Ray job python file. e.g.'
' --entrypoint-file-args=arg1,arg2'
),
)
_ARCHIVE_URIS = base.Argument(
'--archive-uris',
metavar='ARG',
hidden=True,
type=arg_parsers.ArgList(),
action=arg_parsers.UpdateAction,
help=(
'Comma-separated archive URIs that will be copy to the Ray nodes. e.g.'
' --archive-uris=gs://test-bucket/test.tar.gz,gs://test-bucket/test2.tar.gz'
),
)
_CONTAINER_IMAGE_URI = base.Argument(
'--container-image-uri',
metavar='CONTAINER_IMAGE_URI',
help='The container image URI to use for the Ray worker node.',
)
_RESOURCE_SPEC = base.Argument(
'--resource-spec',
type=arg_parsers.ArgDict(
spec={
'resource-unit': int,
'disk-size': int,
'max-node-count': int,
}
),
metavar='RESOURCE_SPEC',
help=textwrap.dedent("""\
Define the worker pool resource spec for the serverless ray job.
The spec can contain the following fields:
*resource-unit*::: Optional. Default to 1. Define how many compute resources(CPU, memory) on each worker node. By default we are using machine e2-standard series, and each resource unit allocates 4 vCPUs and 16GB memory. The resource-unit value can only be 1,2,4,8.
*disk-size*::: Optional, default to 100. Disk size in GB on one worker node.
*max-node-count*::: Optional, default to 2000. The max number of worker nodes this job can occupy while running.
::::
Example:
--resource-spec=resource-unit=2,disk-size=100,max-node-count=10
"""),
)
_SERVERLESS_RAY_JOB_SERVICE_ACCOUNT = base.Argument(
'--service-account',
metavar='SERVICE_ACCOUNT',
hidden=True,
help=(
'The service account to use for the Ray job. If not specified, the'
' default service account is used.'
),
)
def AddCreateServerlessRayJobFlags(parser):
  """Adds flags related to create a serverless ray job.

  Registers the region resource arg, display name, labels, and the module's
  serverless-ray-job flag constants on the given parser.

  Args:
    parser: the argparse parser for the `create` command.
  """
  shared_flags.AddRegionResourceArg(
      parser,
      'to create a serverless ray job',
      prompt_func=region_util.GetPromptForRegionFunc(
          constants.SUPPORTED_TRAINING_REGIONS
      ),
  )
  shared_flags.GetDisplayNameArg('serverless ray job').AddToParser(parser)
  labels_util.AddCreateLabelsFlags(parser)
  _SERVERLESS_RAY_JOB_SERVICE_ACCOUNT.AddToParser(parser)
  _ENTRYPOINT_FILE_URI.AddToParser(parser)
  _RESOURCE_SPEC.AddToParser(parser)
  _ARCHIVE_URIS.AddToParser(parser)
  _ENTRYPOINT_JOB_FILE_ARGS.AddToParser(parser)
  _CONTAINER_IMAGE_URI.AddToParser(parser)
def AddServerlessRayJobResourceArg(
    parser, verb, regions=constants.SUPPORTED_TRAINING_REGIONS
):
  """Add a resource argument for a Vertex AI serverless ray job.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the job resource, such as 'to update'.
    regions: list[str], the list of supported regions.
  """
  region_attribute = shared_flags.RegionAttributeConfig(
      prompt_func=region_util.GetPromptForRegionFunc(regions)
  )
  job_resource_spec = concepts.ResourceSpec(
      resource_collection=serverless_ray_jobs_util.SERVERLESS_RAY_JOB_COLLECTION,
      resource_name='serverless ray job',
      locationsId=region_attribute,
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      disable_auto_completers=False,
  )
  presentation = concept_parsers.ConceptParser.ForResource(
      'serverless_ray_job',
      job_resource_spec,
      'The serverless ray job {}.'.format(verb),
      required=True,
  )
  presentation.AddToParser(parser)

View File

@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform serverless ray jobs commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
# Resource collection path for Vertex AI serverless Ray jobs, used when
# building resource specs and API requests.
SERVERLESS_RAY_JOB_COLLECTION = (
    'aiplatform.projects.locations.serverlessRayJobs'
)
def _ConstructResourceSpecs(aiplatform_client, resource_spec):
"""Constructs the specification of a Ray worker nodepool.
Args:
aiplatform_client: The AI Platform API client used.
resource_spec: A dict whose fields represent the resource spec.
Returns:
A ResoueceSpec message instance for nodepool resource spec for the
serverless ray job.
"""
resource_specs = []
spec = aiplatform_client.GetMessage('ServerlessRayJobSpecResourceSpec')()
resource_spec_dict = resource_spec
if resource_spec_dict.get('disk-size'):
spec.disk = aiplatform_client.GetMessage(
'ServerlessRayJobSpecResourceSpecDisk'
)(diskSizeGb=resource_spec_dict.get('disk-size'))
if resource_spec_dict.get('resource-unit'):
spec.resourceUnit = resource_spec_dict.get('resource-unit')
if resource_spec_dict.get('max-node-count'):
spec.maxNodeCount = resource_spec_dict.get('max-node-count')
print('resource_spec: {}'.format(spec))
resource_specs.append(spec)
return resource_specs
def ConstructServerlessRayJobSpec(
    aiplatform_client,
    main_python_file_uri=None,
    entrypoint_file_args=None,
    archive_uris=None,
    service_account=None,
    container_image_uri=None,
    resource_spec=None,
):
  """Constructs the spec of a serverless ray job to be used in job creation request.

  Args:
    aiplatform_client: The AI Platform API client used.
    main_python_file_uri: The main python file uri of the serverless ray job.
    entrypoint_file_args: The args to pass into the serverless ray job.
    archive_uris: The uris of the archives to be extracted and copy to Ray
      worker nodes.
    service_account: The service account to run the serverless ray job as.
    container_image_uri: The container image uri to run the serverless ray job.
    resource_spec: The resource spec of the nodepool for the serverless ray
      job.

  Returns:
    A ServerlessRayJobSpec message instance for creating a serverless ray job.
  """
  spec = aiplatform_client.GetMessage('ServerlessRayJobSpec')(
      mainPythonFileUri=main_python_file_uri)
  if service_account is not None:
    spec.serviceAccount = service_account
  if archive_uris:
    spec.archiveUris = archive_uris
  if entrypoint_file_args:
    spec.args = entrypoint_file_args
  if resource_spec:
    spec.resourceSpecs = _ConstructResourceSpecs(
        aiplatform_client, resource_spec)
  if container_image_uri:
    # A custom image is carried inside a RuntimeEnv message.
    container = aiplatform_client.GetMessage(
        'ServerlessRayJobSpecRuntimeEnvContainer'
    )(imageUri=container_image_uri)
    runtime_env = aiplatform_client.GetMessage(
        'ServerlessRayJobSpecRuntimeEnv'
    )()
    runtime_env.container = container
    spec.runtimeEnv = runtime_env
  return spec
def _IsKwargsDefined(key, **kwargs):
return key in kwargs and bool(kwargs.get(key))

View File

@@ -0,0 +1,79 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform Tensorboard commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ai.tensorboard_time_series import client
from googlecloudsdk.command_lib.util.apis import arg_utils
from googlecloudsdk.core import resources
def ParseTensorboardOperation(operation_name):
  """Parse operation relative resource name to the operation reference object.

  Args:
    operation_name: The operation resource name

  Returns:
    The operation reference object
  """
  # Derive the collection from the resource segments present in the name.
  collection = 'aiplatform.projects.locations'
  for segment, suffix in (('/tensorboards/', '.tensorboards'),
                          ('/experiments/', '.experiments'),
                          ('/runs/', '.runs')):
    if segment in operation_name:
      collection += suffix
  collection += '.operations'
  try:
    return resources.REGISTRY.ParseRelativeName(
        operation_name, collection=collection)
  except resources.WrongResourceCollectionException:
    # Fall back to the top-level operations collection.
    return resources.REGISTRY.ParseRelativeName(
        operation_name, collection='aiplatform.projects.locations.operations')
# Maps each ValueType enum name to its (flag choice, help text) pair; consumed
# by GetTensorboardTimeSeriesTypeArg to build the --type flag.
_TYPE_CHOICES = {
    'SCALAR': (
        'scalar',
        'Used for tensorboard-time-series that is a list of scalars. E.g. '
        'accuracy of a model over epochs/time.'
    ),
    'TENSOR': (
        'tensor',
        'Used for tensorboard-time-series that is a list of tensors. E.g. '
        'histograms of weights of layer in a model over epoch/time.'
    ),
    'BLOB_SEQUENCE': (
        'blob-sequence',
        'Used for tensorboard-time-series that is a list of blob sequences. '
        'E.g. set of sample images with labels over epochs/time.'
    ),
}
def GetTensorboardTimeSeriesTypeArg(noun):
  """Returns a required --type ChoiceEnumMapper for a tensorboard time series.

  Args:
    noun: str, the noun to interpolate into the flag's help text.
  """
  return arg_utils.ChoiceEnumMapper(
      '--type',
      client.GetMessagesModule(
      ).GoogleCloudAiplatformV1beta1TensorboardTimeSeries
      .ValueTypeValueValuesEnum,
      required=True,
      custom_mappings=_TYPE_CHOICES,
      help_str='Value type of the {noun}.'.format(noun=noun),
      default=None)

View File

@@ -0,0 +1,271 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for validating parameters."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.command_lib.ai import constants
def ValidateDisplayName(display_name):
  """Validates the display name."""
  # None means "not supplied" and is fine; an empty string is rejected.
  if display_name is None:
    return
  if not display_name:
    raise exceptions.InvalidArgumentException(
        '--display-name', 'Display name can not be empty.'
    )
def ValidateRegion(region, available_regions=constants.SUPPORTED_REGION):
  """Validates whether a given region is among the available ones."""
  if region in available_regions:
    return
  raise exceptions.InvalidArgumentException(
      'region',
      'Available values are [{}], but found [{}].'.format(
          ', '.join(available_regions), region
      ),
  )
def ValidateGpuPartitionSize(gpu_partition_size):
  """Validates the gpu partition size."""
  if gpu_partition_size is None:
    return
  if not isinstance(gpu_partition_size, str):
    raise exceptions.InvalidArgumentException(
        '--gpu-partition-size',
        'Required string, but found [{}].'.format(gpu_partition_size),
    )
def GetAndValidateKmsKey(args):
  """Parse CMEK resource arg, and check if the arg was partially specified."""
  if not hasattr(args.CONCEPTS, 'kms_key'):
    return None
  kms_ref = args.CONCEPTS.kms_key.Parse()
  if kms_ref:
    return kms_ref.RelativeName()
  # The concept failed to parse; reject if any component flag was given.
  partial_flags = ('kms_key', 'kms_keyring', 'kms_location', 'kms_project')
  if any(getattr(args, flag, None) for flag in partial_flags):
    raise exceptions.InvalidArgumentException(
        '--kms-key', 'Encryption key not fully specified.'
    )
def ValidateAutoscalingMetricSpecs(specs):
  """Value validation for autoscaling metric specs target name and value."""
  if specs is None:
    return
  for metric_name, target in specs.items():
    if metric_name not in constants.OP_AUTOSCALING_METRIC_NAME_MAPPER:
      supported = ', '.join(
          "'{}'".format(c)
          for c in sorted(constants.OP_AUTOSCALING_METRIC_NAME_MAPPER.keys()))
      raise exceptions.InvalidArgumentException(
          '--autoscaling-metric-specs',
          """Autoscaling metric name can only be one of the following: {}.""".format(
              supported
          ),
      )
    if metric_name == 'request-counts-per-minute':
      # Request counts have no upper bound; only positivity is required.
      if target <= 0:
        raise exceptions.InvalidArgumentException(
            '--autoscaling-metric-specs',
            'Metric target for request-counts-per-minute must be a positive'
            ' value.',
        )
    elif target <= 0 or target > 100:
      # Utilization metrics are percentages.
      raise exceptions.InvalidArgumentException(
          '--autoscaling-metric-specs',
          'Metric target value {} for {} is not between 0 and 100.'.format(
              target, metric_name
          ),
      )
def ValidateRequiredReplicaCount(required_replica_count, min_replica_count):
  """Value validation for required replica count."""
  if required_replica_count is None:
    return
  # min-replica-count defaults to 1 when unset or zero.
  effective_min = min_replica_count if min_replica_count else 1
  if required_replica_count > effective_min:
    raise exceptions.InvalidArgumentException(
        '--required-replica-count',
        'Value must be less than or equal to min-replica-count.'
    )
def ValidateScaleToZeroArgs(
    min_replica_count=None,
    initial_replica_count=None,
    max_replica_count=None,
    min_scaleup_period=None,
    idle_scaledown_period=None,
):
  """Value validation for scale-to-zero args.

  Scale-to-zero is only enabled when min-replica-count is explicitly set to 0;
  each scale-to-zero flag below is rejected unless that holds.

  Args:
    min_replica_count: int or None, the minimum replica count.
    initial_replica_count: int or None, replicas to start with before scaling.
    max_replica_count: int or None, the maximum replica count.
    min_scaleup_period: scale-up period setting, or None.
    idle_scaledown_period: idle scale-down period setting, or None.

  Raises:
    InvalidArgumentException: if a scale-to-zero flag is used without
      min-replica-count == 0, or initial-replica-count exceeds
      max-replica-count.
  """
  # Validation for initial replica count.
  if initial_replica_count is not None:
    if min_replica_count is None:
      raise exceptions.InvalidArgumentException(
          '--initial-replica-count',
          """Cannot set initial-replica-count without explicitly setting
          min-replica-count to 0 to enable scale-to-zero.""",
      )
    if min_replica_count > 0:
      raise exceptions.InvalidArgumentException(
          '--initial-replica-count',
          """Cannot set initial-replica-count when min-replica-count > 0 as
          scale-to-zero will not be enabled.""",
      )
    if (
        max_replica_count is not None
        and max_replica_count < initial_replica_count
    ):
      raise exceptions.InvalidArgumentException(
          '--initial-replica-count',
          """Initial-replica-count must be smaller than max replica count.""",
      )
  # Validation for STZConfig args with min replica count > 0.
  if min_scaleup_period is not None:
    if min_replica_count is None:
      raise exceptions.InvalidArgumentException(
          '--min-scaleup-period',
          """Cannot set min-scaleup-period without explicitly setting
          min-replica-count to 0 to enable scale-to-zero.""",
      )
    if min_replica_count > 0:
      raise exceptions.InvalidArgumentException(
          '--min-scaleup-period',
          """Cannot set min-scaleup-period when min-replica-count > 0 as
          scale-to-zero will not be enabled.""",
      )
  if idle_scaledown_period is not None:
    if min_replica_count is None:
      raise exceptions.InvalidArgumentException(
          '--idle-scaledown-period',
          """Cannot set idle-scaledown-period without explicitly setting
          min-replica-count to 0 to enable scale-to-zero.""",
      )
    if min_replica_count > 0:
      raise exceptions.InvalidArgumentException(
          '--idle-scaledown-period',
          """Cannot set idle-scaledown-period when min-replica-count > 0 as
          scale-to-zero will not be enabled.""",
      )
def ValidateSharedResourceArgs(
    shared_resources_ref=None,
    machine_type=None,
    accelerator_dict=None,
    min_replica_count=None,
    max_replica_count=None,
    required_replica_count=None,
    autoscaling_metric_specs=None,
):
  """Value validation for dedicated resource args while making a shared resource command call.

  Args:
    shared_resources_ref: str or None, the shared deployment resource pool
      full name the model should use, formatted as the full URI
    machine_type: str or None, the type of the machine to serve the model.
    accelerator_dict: dict or None, the accelerator attached to the deployed
      model from args.
    min_replica_count: int or None, the minimum number of replicas the
      deployed model will be always deployed on.
    max_replica_count: int or None, the maximum number of replicas the
      deployed model may be deployed on.
    required_replica_count: int or None, the required number of replicas the
      deployed model will be considered successfully deployed.
    autoscaling_metric_specs: dict or None, the metric specification that
      defines the target resource utilization for calculating the desired
      replica count.

  Raises:
    InvalidArgumentException: if any dedicated-resource flag is combined with
      a shared resource pool, which already fixes machine shape and scaling.
  """
  if shared_resources_ref is None:
    return
  if machine_type is not None:
    raise exceptions.InvalidArgumentException(
        '--machine-type',
        """Cannot use
        machine type and shared resources in the same command.""",
    )
  if accelerator_dict is not None:
    raise exceptions.InvalidArgumentException(
        '--accelerator',
        """Cannot
        use accelerator and shared resources in the same command.""",
    )
  if min_replica_count is not None:
    raise exceptions.InvalidArgumentException(
        '--min-replica-count',
        """Cannot
        use min replica count and shared resources in the same command.""",
    )
  if max_replica_count is not None:
    raise exceptions.InvalidArgumentException(
        '--max-replica-count',
        """Cannot
        use max replica count and shared resources in the same command.""",
    )
  if required_replica_count is not None:
    raise exceptions.InvalidArgumentException(
        '--required-replica-count',
        """Cannot
        use required replica count and shared resources in the same command.""",
    )
  if autoscaling_metric_specs is not None:
    raise exceptions.InvalidArgumentException(
        '--autoscaling-metric-specs',
        """Cannot use autoscaling metric specs
        and shared resources in the same command.""",
    )
def ValidateEndpointArgs(network=None, public_endpoint_enabled=None):
  """Validates the network and public_endpoint_enabled.

  Args:
    network: str or None, the full name of the Compute Engine network to peer
      for a private endpoint.
    public_endpoint_enabled: bool or None, whether a public endpoint was
      requested.

  Raises:
    InvalidArgumentException: if both a private network and a public endpoint
      are requested at the same time.
  """
  if network is not None and public_endpoint_enabled:
    # The original call passed the explanation as the flag-name argument and
    # a fragment (with a typo, "enpdoint") as the message; use a real flag
    # name and one coherent message instead.
    raise exceptions.InvalidArgumentException(
        '--public-endpoint-enabled',
        'Please either set --network for private endpoint, or set'
        ' --public-endpoint-enabled for public endpoint.',
    )
def ValidateModelGardenModelArgs(args):
  """Validates the model garden model args.

  Args:
    args: Namespace with a `model` attribute.

  Raises:
    InvalidArgumentException: if `model` is an empty string or is not in the
      `{publisher}/{model}[@{version}]` format.
  """
  if args.model is None:
    # Nothing to validate when the flag was not supplied. Previously this
    # case fell through and crashed with AttributeError on None.split('/').
    return
  if not args.model:
    raise exceptions.InvalidArgumentException(
        '--model',
        'Model name should not be empty.',
    )
  parts = args.model.split('/')
  if len(parts) != 2 or len(parts[1].split('@')) > 2:
    raise exceptions.InvalidArgumentException(
        '--model',
        'Model name should be in the format of Model Garden, e.g.'
        ' `{publisher_name}/{model_name}@{model_version_name}, e.g.'
        ' `google/gemma2@gemma-2-2b` or in the format of Hugging Face'
        ' convention, e.g. `meta-llama/Meta-Llama-3-8B`. You can use the'
        ' `gcloud ai model-garden models list` command to find supported'
        ' models.',
    )