feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command group for ai-platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
# Registered in all three release tracks: alpha, GA, and beta.
@base.ReleaseTracks(base.ReleaseTrack.ALPHA,
                    base.ReleaseTrack.GA,
                    base.ReleaseTrack.BETA)
class MlEngine(base.Group):
  """Manage AI Platform jobs and models.
The {command} command group lets you manage AI Platform jobs and
training models.
AI Platform is a managed service that enables you to easily build
machine
learning models, that work on any type of data, of any size. Create your model
with the powerful TensorFlow framework that powers many Google products, from
Google Photos to Google Cloud Speech.
More information on AI Platform can be found here:
https://cloud.google.com/ml
and detailed documentation can be found here:
https://cloud.google.com/ml/docs/
  """

  # Places this group under the AI & Machine Learning section of `gcloud help`.
  category = base.AI_AND_MACHINE_LEARNING_CATEGORY

  def Filter(self, context, args):
    """Runs before every command in this group.

    Validates the project configuration and registers the `ml` v1 API so
    subcommands can parse its resource names.

    Args:
      context: dict, shared initialization context for commands (unused here).
      args: argparse.Namespace, the parsed arguments for the invoked command.
    """
    # TODO(b/190522169): Determine if command group works with project number
    base.RequireProjectID(args)
    del context, args  # Unused after the project-ID check above.
    base.DisableUserProjectQuota()
    resources.REGISTRY.RegisterApiByName('ml', 'v1')

View File

@@ -0,0 +1,141 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform explain command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import predict
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import predict_utilities
from googlecloudsdk.command_lib.ml_engine import region_util
# Maximum number of instances this command reads from the input; the service's
# own request payload limit still applies on top of this.
INPUT_INSTANCES_LIMIT = 100

# Help sections rendered in `... explain --help`. Note: `{command}` already
# expands to the full `gcloud ai-platform explain` invocation, so the example
# must not repeat the `explain` verb after it.
DETAILED_HELP = {
    'EXAMPLES':
        """\
To get explanations for an AI Platform version model with the
version 'version' and with the name 'model-name', run:
$ {command} --model=model-name --version=version \
--json-instances=instances.json
""",
}
# Explanations are only offered in the alpha and beta tracks.
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class Explain(base.Command):
  """Run AI Platform explanation.
`{command}` sends an explain request to AI Platform for the given
instances. This command will read up to 100 instances, though the service
itself will accept instances up to the payload limit size (currently,
1.5MB).
  """

  detailed_help = DETAILED_HELP

  @staticmethod
  def Args(parser):
    """Register flags for this command."""
    parser.add_argument('--model', required=True, help='Name of the model.')
    # Region flag also accepts the special `global` endpoint value.
    flags.GetRegionArg(include_global=True).AddToParser(parser)
    parser.add_argument(
        '--version',
        help="""\
Model version to be used.
If unspecified, the default version of the model will be used. To list model
versions run
$ {parent_command} versions list
""")
    # Exactly one of the three instance-input flags must be supplied.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        '--json-request',
        help="""\
Path to a local file containing the body of JSON request.
An example of a JSON request:
{
"instances": [
{"x": [1, 2], "y": [3, 4]},
{"x": [-1, -2], "y": [-3, -4]}
]
}
This flag accepts "-" for stdin.
""")
    group.add_argument(
        '--json-instances',
        help="""\
Path to a local file from which instances are read.
Instances are in JSON format; newline delimited.
An example of the JSON instances file:
{"images": [0.0, ..., 0.1], "key": 3}
{"images": [0.0, ..., 0.1], "key": 2}
...
This flag accepts "-" for stdin.
""")
    group.add_argument(
        '--text-instances',
        help="""\
Path to a local file from which instances are read.
Instances are in UTF-8 encoded text format; newline delimited.
An example of the text instances file:
107,4.9,2.5,4.5,1.7
100,5.7,2.8,4.1,1.3
...
This flag accepts "-" for stdin.
""")

  def Run(self, args):
    """This is what gets called when the user runs this command.
    Args:
      args: an argparse namespace. All the arguments that were provided to this
        command invocation.
    Returns:
      Some value that we want to have printed later.
    """
    # Reads at most INPUT_INSTANCES_LIMIT instances from whichever of the
    # mutually-exclusive input flags was provided.
    instances = predict_utilities.ReadInstancesFromArgs(
        args.json_request,
        args.json_instances,
        args.text_instances,
        limit=INPUT_INSTANCES_LIMIT)
    region = region_util.GetRegion(args)
    # Point the ML API client at the regional endpoint for this request only.
    with endpoint_util.MlEndpointOverrides(region=region):
      model_or_version_ref = predict_utilities.ParseModelOrVersionRef(
          args.model, args.version)
      results = predict.Explain(model_or_version_ref, instances)
    if not args.IsSpecified('format'):
      # default format is based on the response.
      args.format = predict_utilities.GetDefaultFormat(
          results.get('predictions'))
    return results

View File

@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command group for ai-platform jobs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
class Jobs(base.Group):
  """AI Platform Jobs commands."""
  # No body needed beyond the docstring: the calliope framework discovers the
  # subcommands from the package layout. (The redundant `pass` was removed —
  # a docstring alone is a complete class body.)

View File

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform jobs cancel command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import jobs
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import jobs_util
# Help sections rendered in `... jobs cancel --help`.
DETAILED_HELP = {
    'EXAMPLES':
        """\
To cancel a running AI Platform job named ``my-job'', run:
$ {command} my-job
"""
}
def _AddCancelArgs(parser):
  """Registers the positional JOB name argument for `jobs cancel`."""
  flags.JOB_NAME.AddToParser(parser)
# SilentCommand: prints no resource output on success.
class Cancel(base.SilentCommand):
  """Cancel a running AI Platform job."""

  detailed_help = DETAILED_HELP

  @staticmethod
  def Args(parser):
    _AddCancelArgs(parser)

  def Run(self, args):
    # Delegates to the shared jobs utility; returns the API response.
    return jobs_util.Cancel(jobs.JobsClient(), args.job)
# Final help text for the command. Built on top of DETAILED_HELP (defined
# above) so that reassigning `Cancel.detailed_help` here does not drop the
# EXAMPLES section that was registered on the class earlier — the previous
# standalone dict silently lost it.
_DETAILED_HELP = dict(
    DETAILED_HELP,
    DESCRIPTION="""\
*{command}* cancels a running AI Platform job. If the job is already
finished, the command will not perform an operation and exit successfully.
""")
Cancel.detailed_help = _DETAILED_HELP

View File

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform jobs describe command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import jobs
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import jobs_util
from googlecloudsdk.core import log
# Help sections rendered in `... jobs describe --help`. The example command is
# prefixed with `$ ` — the marker the gcloud doc generator uses to format
# example invocations — matching every sibling command's EXAMPLES section.
DETAILED_HELP = {
    'EXAMPLES':
        """\
To describe the AI Platform job named ``my-job'', run:
$ {command} my-job
"""
}
def _AddDescribeArgs(parser):
  """Registers the JOB name argument and --summarize flag for describe."""
  flags.JOB_NAME.AddToParser(parser)
  flags.GetSummarizeFlag().AddToParser(parser)
class Describe(base.DescribeCommand):
  """Describe an AI Platform job."""

  detailed_help = DETAILED_HELP

  @staticmethod
  def Args(parser):
    _AddDescribeArgs(parser)

  def Run(self, args):
    job = jobs_util.Describe(jobs.JobsClient(), args.job)
    self.job = job  # Hack to make the Epilog method work
    if args.summarize:
      # --summarize imposes its own output format, so an explicit --format
      # from the user is overridden (with a warning).
      if args.format:
        log.warning('--format is ignored when --summarize is present')
      args.format = jobs_util.GetSummaryFormat(job)
    return job

  def Epilog(self, resources_were_displayed):
    # Only print the follow-up hint when the describe output was shown;
    # relies on self.job stashed by Run() above.
    if resources_were_displayed:
      jobs_util.PrintDescribeFollowUp(self.job.jobId)

View File

@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform jobs list command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import jobs
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import jobs_util
# Default table layout for `jobs list` output (calliope display format).
_DEFAULT_FORMAT = """
table(
jobId.basename(),
state:label=STATUS,
createTime.date(tz=LOCAL):label=CREATED
)
"""

# Help sections rendered in `... jobs list --help`.
DETAILED_HELP = {
    'EXAMPLES':
        """\
To list the existing AI Platform jobs, run:
$ {command}
"""
}
class List(base.ListCommand):
  """List existing AI Platform jobs."""

  detailed_help = DETAILED_HELP

  @staticmethod
  def Args(parser):
    parser.display_info.AddFormat(_DEFAULT_FORMAT)

  def Run(self, args):
    # args carries only the standard list flags; none are needed here.
    return jobs_util.List(jobs.JobsClient())

View File

@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform jobs stream-logs command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import jobs_util
from googlecloudsdk.command_lib.ml_engine import log_utils
from googlecloudsdk.core import properties
# Help sections rendered in `... jobs stream-logs --help`.
DETAILED_HELP = {
    'EXAMPLES':
        """\
To show the logs from running the AI Platform job ``my-job'', run:
$ {command} my-job
"""
}
def _AddStreamLogsArgs(parser):
  """Registers the job name and log-streaming flags for stream-logs."""
  flags.JOB_NAME.AddToParser(parser)
  flags.POLLING_INTERVAL.AddToParser(parser)
  flags.ALLOW_MULTILINE_LOGS.AddToParser(parser)
  # Optional filter to a single task within the job.
  flags.TASK_NAME.AddToParser(parser)
class StreamLogs(base.Command):
  """Show logs from a running AI Platform job."""

  detailed_help = DETAILED_HELP

  @staticmethod
  def Args(parser):
    _AddStreamLogsArgs(parser)
    parser.display_info.AddFormat(log_utils.LOG_FORMAT)

  def Run(self, args):
    """Run the stream-logs command."""
    # The polling interval is read from the ml_engine/polling_interval
    # property (the --polling-interval flag stores into it).
    return jobs_util.StreamLogs(
        args.job, args.task_name,
        properties.VALUES.ml_engine.polling_interval.GetInt(),
        args.allow_multiline_logs)

View File

@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command group for ai-platform jobs submit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
class Submit(base.Group):
  """AI Platform Jobs submit commands."""
  # No body needed beyond the docstring: calliope discovers the subcommands
  # (training, prediction) from the package layout. (Removed redundant `pass`.)

View File

@@ -0,0 +1,179 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform jobs submit batch prediction command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import jobs
from googlecloudsdk.calliope import arg_parsers
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import jobs_util
from googlecloudsdk.command_lib.util.args import labels_util
def _AddAcceleratorFlags(parser):
  """Add arguments for accelerator config."""
  # --accelerator-count and the accelerator-type choice flag are grouped so
  # they render together in help output.
  accelerator_config_group = base.ArgumentGroup(
      help='Accelerator Configuration.')
  accelerator_config_group.AddArgument(base.Argument(
      '--accelerator-count',
      required=True,
      default=1,
      type=arg_parsers.BoundedInt(lower_bound=1),
      help=('The number of accelerators to attach to the machines.'
            ' Must be >= 1.')))
  # Choice flag mapping user-facing accelerator type names to API enum values.
  accelerator_config_group.AddArgument(
      jobs_util.AcceleratorFlagMap().choice_arg)
  accelerator_config_group.AddToParser(parser)
def _AddSubmitPredictionArgs(parser):
  """Add arguments for `jobs submit prediction` command."""
  parser.add_argument('job', help='Name of the batch prediction job.')
  # The model is identified either by a Cloud Storage path (--model-dir) or by
  # a deployed model name (--model); exactly one must be given.
  model_group = parser.add_mutually_exclusive_group(required=True)
  model_group.add_argument(
      '--model-dir',
      help=('Cloud Storage location where '
            'the model files are located.'))
  model_group.add_argument(
      '--model', help='Name of the model to use for prediction.')
  parser.add_argument(
      '--version',
      help="""\
Model version to be used.
This flag may only be given if --model is specified. If unspecified, the default
version of the model will be used. To list versions for a model, run
$ gcloud ai-platform versions list
""")
  # input location is a repeated field.
  parser.add_argument(
      '--input-paths',
      type=arg_parsers.ArgList(min_length=1),
      required=True,
      metavar='INPUT_PATH',
      help="""\
Cloud Storage paths to the instances to run prediction on.
Wildcards (```*```) accepted at the *end* of a path. More than one path can be
specified if multiple file patterns are needed. For example,
gs://my-bucket/instances*,gs://my-bucket/other-instances1
will match any objects whose names start with `instances` in `my-bucket` as well
as the `other-instances1` bucket, while
gs://my-bucket/instance-dir/*
will match any objects in the `instance-dir` "directory" (since directories
aren't a first-class Cloud Storage concept) of `my-bucket`.
""")
  # --data-format choice flag from the shared flag map.
  jobs_util.DataFormatFlagMap().choice_arg.AddToParser(parser)
  parser.add_argument(
      '--output-path', required=True,
      help='Cloud Storage path to which to save the output. '
      'Example: gs://my-bucket/output.')
  parser.add_argument(
      '--region',
      required=True,
      help='The Compute Engine region to run the job in.')
  parser.add_argument(
      '--max-worker-count',
      required=False,
      type=int,
      help=('The maximum number of workers to be used for parallel processing. '
            'Defaults to 10 if not specified.'))
  parser.add_argument(
      '--batch-size',
      required=False,
      type=int,
      help=('The number of records per batch. The service will buffer '
            'batch_size number of records in memory before invoking TensorFlow.'
            ' Defaults to 64 if not specified.'))
  flags.SIGNATURE_NAME.AddToParser(parser)
  flags.RUNTIME_VERSION.AddToParser(parser)
  # Standard --labels flag for attaching user labels at creation time.
  labels_util.AddCreateLabelsFlags(parser)
# GA/beta variant; the alpha variant below additionally supports accelerators.
@base.ReleaseTracks(base.ReleaseTrack.GA,
                    base.ReleaseTrack.BETA)
class Prediction(base.Command):
  """Start an AI Platform batch prediction job."""

  @staticmethod
  def Args(parser):
    _AddSubmitPredictionArgs(parser)
    parser.display_info.AddFormat(jobs_util.JOB_FORMAT)

  def Run(self, args):
    # Translate the user-facing --data-format choice into the API enum name.
    data_format = jobs_util.DataFormatFlagMap().GetEnumForChoice(
        args.data_format)
    jobs_client = jobs.JobsClient()
    labels = jobs_util.ParseCreateLabels(jobs_client, args)
    return jobs_util.SubmitPrediction(
        jobs_client, args.job,
        model_dir=args.model_dir,
        model=args.model,
        version=args.version,
        input_paths=args.input_paths,
        data_format=data_format.name,
        output_path=args.output_path,
        region=args.region,
        runtime_version=args.runtime_version,
        max_worker_count=args.max_worker_count,
        batch_size=args.batch_size,
        signature_name=args.signature_name,
        labels=labels)
# Alpha-only variant: identical to Prediction above, plus accelerator flags.
@base.ReleaseTracks(base.ReleaseTrack.ALPHA)
class PredictionAlpha(base.Command):
  """Start an AI Platform batch prediction job."""

  @staticmethod
  def Args(parser):
    _AddSubmitPredictionArgs(parser)
    # Accelerator type/count flags are alpha-only.
    _AddAcceleratorFlags(parser)
    parser.display_info.AddFormat(jobs_util.JOB_FORMAT)

  def Run(self, args):
    # Translate the user-facing --data-format choice into the API enum name.
    data_format = jobs_util.DataFormatFlagMap().GetEnumForChoice(
        args.data_format)
    jobs_client = jobs.JobsClient()
    labels = jobs_util.ParseCreateLabels(jobs_client, args)
    return jobs_util.SubmitPrediction(
        jobs_client, args.job,
        model_dir=args.model_dir,
        model=args.model,
        version=args.version,
        input_paths=args.input_paths,
        data_format=data_format.name,
        output_path=args.output_path,
        region=args.region,
        runtime_version=args.runtime_version,
        max_worker_count=args.max_worker_count,
        batch_size=args.batch_size,
        signature_name=args.signature_name,
        labels=labels,
        # Alpha-only accelerator configuration.
        accelerator_type=args.accelerator_type,
        accelerator_count=args.accelerator_count)

View File

@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform jobs submit training command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import jobs
from googlecloudsdk.calliope import base
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.command_lib.compute import flags as compute_flags
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import jobs_util
from googlecloudsdk.command_lib.util.args import labels_util
def _AddSubmitTrainingArgs(parser):
  """Add arguments for `jobs submit training` command."""
  flags.JOB_NAME.AddToParser(parser)
  flags.PACKAGE_PATH.AddToParser(parser)
  flags.PACKAGES.AddToParser(parser)
  flags.GetModuleNameFlag(required=False).AddToParser(parser)
  compute_flags.AddRegionFlag(parser, 'machine learning training job',
                              'submit')
  flags.CONFIG.AddToParser(parser)
  flags.STAGING_BUCKET.AddToParser(parser)
  flags.GetJobDirFlag(upload_help=True).AddToParser(parser)
  flags.GetUserArgs(local=False).AddToParser(parser)
  jobs_util.ScaleTierFlagMap().choice_arg.AddToParser(parser)
  flags.RUNTIME_VERSION.AddToParser(parser)
  flags.AddPythonVersionFlag(parser, 'during training')
  flags.TRAINING_SERVICE_ACCOUNT.AddToParser(parser)
  flags.ENABLE_WEB_ACCESS.AddToParser(parser)
  # --async (deprecated, default behavior) and --stream-logs are two mutually
  # exclusive ways of choosing sync vs. async execution.
  sync_group = parser.add_mutually_exclusive_group()
  # TODO(b/36195821): Use the flag deprecation machinery when it supports the
  # store_true action
  sync_group.add_argument(
      '--async', action='store_true', dest='async_', help=(
          '(DEPRECATED) Display information about the operation in progress '
          'without waiting for the operation to complete. '
          'Enabled by default and can be omitted; use `--stream-logs` to run '
          'synchronously.'))
  sync_group.add_argument(
      '--stream-logs',
      action='store_true',
      help=('Block until job completion and stream the logs while the job runs.'
            '\n\n'
            'Note that even if command execution is halted, the job will still '
            'run until cancelled with\n\n'
            ' $ gcloud ai-platform jobs cancel JOB_ID'))
  labels_util.AddCreateLabelsFlags(parser)
def _GetAndValidateKmsKey(args):
  """Parse CMEK resource arg, and check if the arg was partially specified.

  Args:
    args: argparse.Namespace with the (optional) kms-key concept flags parsed.

  Returns:
    The fully-qualified KMS key resource name, or None when no key (and no
    key component flags) were supplied.

  Raises:
    exceptions.InvalidArgumentException: if only some of the kms-key
      component flags were given, so the key cannot be resolved.
  """
  # Older release tracks may not register the concept at all.
  if not hasattr(args.CONCEPTS, 'kms_key'):
    return None
  kms_ref = args.CONCEPTS.kms_key.Parse()
  if kms_ref:
    return kms_ref.RelativeName()
  # The key did not resolve; if any component flag was set, the user gave a
  # partial specification — surface that instead of silently ignoring it.
  component_flags = ('kms-key', 'kms-keyring', 'kms-location', 'kms-project')
  if any(getattr(args, name.replace('-', '_'), None)
         for name in component_flags):
    raise exceptions.InvalidArgumentException(
        '--kms-key', 'Encryption key not fully specified.')
@base.ReleaseTracks(base.ReleaseTrack.GA)
class Train(base.Command):
  """Submit an AI Platform training job."""

  # GA does not accept a TPU TensorFlow version for custom containers; the
  # alpha/beta subclass overrides this to True.
  _SUPPORT_TPU_TF_VERSION = False

  @classmethod
  def Args(cls, parser):
    _AddSubmitTrainingArgs(parser)
    flags.AddCustomContainerFlags(
        parser, support_tpu_tf_version=cls._SUPPORT_TPU_TF_VERSION)
    flags.AddKmsKeyFlag(parser, 'job')
    parser.display_info.AddFormat(jobs_util.JOB_FORMAT)

  def Run(self, args):
    # --stream-logs (synchronous) takes precedence over the deprecated
    # --async flag; the helper resolves the combination.
    stream_logs = jobs_util.GetStreamLogs(args.async_, args.stream_logs)
    scale_tier = jobs_util.ScaleTierFlagMap().GetEnumForChoice(args.scale_tier)
    scale_tier_name = scale_tier.name if scale_tier else None
    jobs_client = jobs.JobsClient()
    labels = jobs_util.ParseCreateLabels(jobs_client, args)
    # Validate custom-container settings before submitting anything.
    custom_container_config = (
        jobs_util.TrainingCustomInputServerConfig.FromArgs(
            args, self._SUPPORT_TPU_TF_VERSION))
    custom_container_config.ValidateConfig()
    job = jobs_util.SubmitTraining(
        jobs_client,
        args.job,
        job_dir=args.job_dir,
        staging_bucket=args.staging_bucket,
        packages=args.packages,
        package_path=args.package_path,
        scale_tier=scale_tier_name,
        config=args.config,
        module_name=args.module_name,
        runtime_version=args.runtime_version,
        python_version=args.python_version,
        # --network is only registered on the alpha/beta subclass.
        network=args.network if hasattr(args, 'network') else None,
        service_account=args.service_account,
        labels=labels,
        stream_logs=stream_logs,
        user_args=args.user_args,
        kms_key=_GetAndValidateKmsKey(args),
        custom_train_server_config=custom_container_config,
        enable_web_access=args.enable_web_access)
    # If the job itself failed, we will return a failure status.
    if stream_logs and job.state is not job.StateValueValuesEnum.SUCCEEDED:
      self.exit_code = 1
    return job
# Alpha/beta variant: inherits Run() from Train; differs only in the flags it
# registers (--network, TPU TF version support for custom containers).
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class TrainAlphaBeta(Train):
  """Submit an AI Platform training job."""

  _SUPPORT_TPU_TF_VERSION = True

  @classmethod
  def Args(cls, parser):
    _AddSubmitTrainingArgs(parser)
    flags.AddKmsKeyFlag(parser, 'job')
    # VPC network peering flag is alpha/beta-only.
    flags.NETWORK.AddToParser(parser)
    flags.AddCustomContainerFlags(
        parser, support_tpu_tf_version=cls._SUPPORT_TPU_TF_VERSION)
    parser.display_info.AddFormat(jobs_util.JOB_FORMAT)
# Long-form help attached to Train below; TrainAlphaBeta inherits it.
_DETAILED_HELP = {
    'DESCRIPTION':
        r"""Submit an AI Platform training job.
This creates temporary files and executes Python code staged
by a user on Cloud Storage. Model code can either be
specified with a path, e.g.:
$ {command} my_job \
--module-name trainer.task \
--staging-bucket gs://my-bucket \
--package-path /my/code/path/trainer \
--packages additional-dep1.tar.gz,dep2.whl
Or by specifying an already built package:
$ {command} my_job \
--module-name trainer.task \
--staging-bucket gs://my-bucket \
--packages trainer-0.0.1.tar.gz,additional-dep1.tar.gz,dep2.whl
If `--package-path=/my/code/path/trainer` is specified and there is a
`setup.py` file at `/my/code/path/setup.py`, the setup file will be invoked
with `sdist` and the generated tar files will be uploaded to Cloud Storage.
Otherwise, a temporary `setup.py` file will be generated for the build.
By default, this command runs asynchronously; it exits once the job is
successfully submitted.
To follow the progress of your job, pass the `--stream-logs` flag (note that
even with the `--stream-logs` flag, the job will continue to run after this
command exits and must be cancelled with `gcloud ai-platform jobs cancel JOB_ID`).
For more information, see:
https://cloud.google.com/ai-platform/training/docs/overview
"""
}
Train.detailed_help = _DETAILED_HELP

View File

@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform jobs update command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import jobs
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import jobs_util
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import log
# Help sections rendered in `... jobs update --help`.
DETAILED_HELP = {
    'EXAMPLES':
        """\
To remove all labels in the AI Platform job named ``my-job'', run:
$ {command} my-job --clear-labels
"""
}
def _AddUpdateArgs(parser):
  """Get arguments for the `ai-platform jobs update` command."""
  flags.JOB_NAME.AddToParser(parser)
  # --update-labels / --remove-labels / --clear-labels flags.
  labels_util.AddUpdateLabelsFlags(parser)
class Update(base.UpdateCommand):
  """Update an AI Platform job."""

  detailed_help = DETAILED_HELP

  @staticmethod
  def Args(parser):
    _AddUpdateArgs(parser)

  def Run(self, args):
    jobs_client = jobs.JobsClient()
    # The helper reads the label-update flags directly off `args`.
    updated_job = jobs_util.Update(jobs_client, args)
    # Emit the standard "Updated ..." status line for the resource.
    log.UpdatedResource(args.job, kind='ml engine job')
    return updated_job

View File

@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command group for ai-platform local."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
class Local(base.Group):
  """AI Platform Local commands."""
  # No body needed beyond the docstring: calliope discovers the subcommands
  # (predict, train) from the package layout. (Removed redundant `pass`.)

View File

@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform local predict command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import local_utils
from googlecloudsdk.command_lib.ml_engine import predict_utilities
from googlecloudsdk.core import log
def _AddLocalPredictArgs(parser):
  """Add arguments for `gcloud ai-platform local predict` command."""
  parser.add_argument('--model-dir', required=True, help='Path to the model.')
  # --framework choice flag (tensorflow / scikit-learn / xgboost).
  flags.FRAMEWORK_MAPPER.choice_arg.AddToParser(parser)
  # Exactly one of the three instance-input flags must be supplied.
  group = parser.add_mutually_exclusive_group(required=True)
  group.add_argument(
      '--json-request',
      help="""\
Path to a local file containing the body of JSON request.
An example of a JSON request:
{
"instances": [
{"x": [1, 2], "y": [3, 4]},
{"x": [-1, -2], "y": [-3, -4]}
]
}
This flag accepts "-" for stdin.
""")
  group.add_argument(
      '--json-instances',
      help="""\
Path to a local file from which instances are read.
Instances are in JSON format; newline delimited.
An example of the JSON instances file:
{"images": [0.0, ..., 0.1], "key": 3}
{"images": [0.0, ..., 0.1], "key": 2}
...
This flag accepts "-" for stdin.
""")
  group.add_argument(
      '--text-instances',
      help="""\
Path to a local file from which instances are read.
Instances are in UTF-8 encoded text format; newline delimited.
An example of the text instances file:
107,4.9,2.5,4.5,1.7
100,5.7,2.8,4.1,1.3
...
This flag accepts "-" for stdin.
""")
  flags.SIGNATURE_NAME.AddToParser(parser)
class Predict(base.Command):
  """Run prediction locally."""

  @staticmethod
  def Args(parser):
    _AddLocalPredictArgs(parser)

  def Run(self, args):
    framework = flags.FRAMEWORK_MAPPER.GetEnumForChoice(args.framework)
    # When no --framework choice was given, default to TensorFlow.
    framework_flag = framework.name.lower() if framework else 'tensorflow'
    if args.signature_name is None:
      log.status.Print('If the signature defined in the model is '
                       'not serving_default then you must specify it via '
                       '--signature-name flag, otherwise the command may fail.')
    results = local_utils.RunPredict(
        args.model_dir,
        json_request=args.json_request,
        json_instances=args.json_instances,
        text_instances=args.text_instances,
        framework=framework_flag,
        signature_name=args.signature_name)
    if not args.IsSpecified('format'):
      # default format is based on the response.
      # The local runner may return either a bare list of predictions or a
      # dict with a 'predictions' key, depending on the framework.
      if isinstance(results, list):
        predictions = results
      else:
        predictions = results.get('predictions')
      args.format = predict_utilities.GetDefaultFormat(predictions)
    return results
# Detailed help for the Predict command; attached after the class definition
# so `gcloud ai-platform local predict --help` renders it.
_DETAILED_HELP = {
    'DESCRIPTION':
        """\
*{command}* performs prediction locally with the given instances. It requires the
[TensorFlow SDK](https://www.tensorflow.org/install) be installed locally. The
output format mirrors `gcloud ai-platform predict` (online prediction).
You cannot use this command with custom prediction routines.
"""
}

Predict.detailed_help = _DETAILED_HELP

View File

@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform local train command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import os
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import local_train
from googlecloudsdk.core import log
from googlecloudsdk.core.util import files
# Warning template emitted when cluster-shape flags (--worker-count,
# --parameter-server-count) are passed without --distributed; {flag} is
# filled in with the offending flag name.
_BAD_FLAGS_WARNING_MESSAGE = """\
{flag} is ignored if --distributed is not provided.
Did you mean to run distributed training?\
"""
class RunLocal(base.Command):
  r"""Run an AI Platform training job locally.

  This command runs the specified module in an environment
  similar to that of a live AI Platform Training Job.

  This is especially useful in the case of testing distributed models,
  as it allows you to validate that you are properly interacting with the
  AI Platform cluster configuration. If your model expects a specific
  number of parameter servers or workers (i.e. you expect to use the CUSTOM
  machine type), use the --parameter-server-count and --worker-count flags to
  further specify the desired cluster configuration, just as you would in
  your cloud training job configuration:

      $ {command} --module-name trainer.task \
          --package-path /path/to/my/code/trainer \
          --distributed \
          --parameter-server-count 4 \
          --worker-count 8

  Unlike submitting a training job, the --package-path parameter can be
  omitted, and will use your current working directory.

  AI Platform Training sets a TF_CONFIG environment variable on each VM in
  your training job. You can use TF_CONFIG to access the cluster description
  and the task description for each VM.

  Learn more about TF_CONFIG:
  https://cloud.google.com/ai-platform/training/docs/distributed-training-details.
  """

  @staticmethod
  def Args(parser):
    """Register flags for this command."""
    flags.PACKAGE_PATH.AddToParser(parser)
    flags.GetModuleNameFlag().AddToParser(parser)
    flags.DISTRIBUTED.AddToParser(parser)
    flags.EVALUATORS.AddToParser(parser)
    flags.PARAM_SERVERS.AddToParser(parser)
    flags.GetJobDirFlag(upload_help=False, allow_local=True).AddToParser(parser)
    flags.WORKERS.AddToParser(parser)
    flags.START_PORT.AddToParser(parser)
    flags.GetUserArgs(local=True).AddToParser(parser)

  def Run(self, args):
    """Runs the training module locally, mimicking a cloud training job.

    Args:
      args: an argparse namespace. All the arguments that were provided to this
        command invocation.

    Returns:
      None. The child process's exit status is propagated via self.exit_code.
    """
    package_path = args.package_path or files.GetCWD()
    # Mimic behavior of ai-platform jobs submit training: the module runs with
    # its package's parent directory as the root.
    package_root = os.path.dirname(os.path.abspath(package_path))
    # Copy the user args: extending args.user_args in place would mutate the
    # list stored on the argparse namespace.
    user_args = list(args.user_args or [])
    if args.job_dir:
      user_args.extend(('--job-dir', args.job_dir))
    # Default cluster shape for --distributed is 2 workers + 2 parameter
    # servers, matching historical behavior.
    worker_count = 2 if args.worker_count is None else args.worker_count
    ps_count = (2 if args.parameter_server_count is None
                else args.parameter_server_count)
    if args.distributed:
      retval = local_train.RunDistributed(
          args.module_name,
          package_root,
          ps_count,
          worker_count,
          args.evaluator_count or 0,
          args.start_port,
          user_args=user_args)
    else:
      # Cluster-shape flags have no effect without --distributed; warn so the
      # user knows they were ignored rather than silently dropping them.
      if args.parameter_server_count:
        log.warning(_BAD_FLAGS_WARNING_MESSAGE.format(
            flag='--parameter-server-count'))
      if args.worker_count:
        log.warning(_BAD_FLAGS_WARNING_MESSAGE.format(flag='--worker-count'))
      retval = local_train.MakeProcess(
          args.module_name,
          package_root,
          args=user_args,
          task_type=local_train.GetPrimaryNodeName())
    # Don't raise an exception because the users will already see the message.
    # We want this to mimic calling the script directly as much as possible.
    self.exit_code = retval

View File

@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command group for ai-platform locations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
# Command group: `gcloud alpha ai-platform locations`.
@base.ReleaseTracks(base.ReleaseTrack.ALPHA)
class Locations(base.Group):
  """Query AI Platform location capabilities.

  The {command} command group lets you query AI Platform locations.
  """

View File

@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform models describe command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import locations
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
# Resource collection backing this surface. Not referenced in the visible
# code; kept for consistency with sibling command modules.
_COLLECTION = 'ml.projects.locations'


def _AddDescribeArgs(parser):
  """Register flags for `ai-platform locations describe`."""
  location_arg = flags.GetLocationResourceArg()
  location_arg.AddToParser(parser)
  region_arg = flags.GetRegionArg()
  region_arg.AddToParser(parser)
  # Capability payloads are nested, so JSON is the most readable default.
  parser.display_info.AddFormat('json')
def _Run(args):
  """Fetch one location's capabilities via the appropriate endpoint."""
  with endpoint_util.MlEndpointOverrides(region=args.region):
    client = locations.LocationsClient()
    return client.Get(args.location)
@base.ReleaseTracks(base.ReleaseTrack.ALPHA)
class DescribeAlpha(base.DescribeCommand):
  """Display AI Platform capabilities in a location."""

  @staticmethod
  def Args(parser):
    _AddDescribeArgs(parser)

  def Run(self, args):
    # Delegates to the module-level helper, which applies endpoint overrides.
    return _Run(args)

View File

@@ -0,0 +1,24 @@
# Declarative calliope spec for `gcloud alpha ai-platform locations list`.
- release_tracks: [ALPHA]

  help_text:
    brief: List AI Platform Locations.
    description: List AI Platform Locations.

  request:
    # List over the projects.locations collection of the ml API.
    collection: ml.projects.locations

  response:
    # Each location resource is identified by its `name` field.
    id_field: name

  arguments:
    resource:
      help_text: The parent project of the locations you want to list.
      spec: !REF googlecloudsdk.command_lib.ml_engine.resources:project

  output:
    format: |
      table(
        name,
        capabilities.type,
        capabilities.availableAccelerators.list(separator=',')
      )

View File

@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command group for ai-platform models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
class Models(base.Group):
  """AI Platform Models commands.

  An AI Platform model is a container representing an ML application or
  service. A model may contain multiple versions which act as the
  implementation of the service. See also:

      $ {parent_command} versions --help.

  For more information, please see
  https://cloud.google.com/ml/docs/concepts/technical-overview#model
  """
  # No body needed: the docstring already makes the class statement complete,
  # so the trailing `pass` was redundant and has been removed.

View File

@@ -0,0 +1,99 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command to add IAM policy binding for a model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import models
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.iam import iam_util
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import models_util
from googlecloudsdk.command_lib.ml_engine import region_util
def _AddIamPolicyBindingFlags(parser, add_condition=False):
  """Register flags for the add-iam-policy-binding commands.

  Args:
    parser: the argparse parser to register flags on.
    add_condition: bool, whether to accept an optional IAM condition.
  """
  model_arg = flags.GetModelName()
  model_arg.AddToParser(parser)
  region_arg = flags.GetRegionArg(include_global=True)
  region_arg.AddToParser(parser)
  iam_util.AddArgsForAddIamPolicyBinding(
      parser, flags.MlEngineIamRolesCompleter, add_condition=add_condition)
def _Run(args):
  """Add the requested member/role binding to the model's IAM policy."""
  region = region_util.GetRegion(args)
  with endpoint_util.MlEndpointOverrides(region=region):
    client = models.ModelsClient()
    return models_util.AddIamPolicyBinding(
        client, args.model, args.member, args.role)
@base.ReleaseTracks(base.ReleaseTrack.GA)
class AddIamPolicyBinding(base.Command):
  """Add IAM policy binding to a model."""

  detailed_help = iam_util.GetDetailedHelpForAddIamPolicyBinding(
      'model', 'my_model', role='roles/ml.admin', condition=False)

  @staticmethod
  def Args(parser):
    """Registers model/region/binding flags (no IAM condition on GA)."""
    _AddIamPolicyBindingFlags(parser)

  def Run(self, args):
    # Delegates to the shared module-level helper.
    return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.BETA)
class AddIamPolicyBindingBeta(AddIamPolicyBinding):
  """Add IAM policy binding to a model."""
  # The beta command is identical to GA: detailed_help, Args, and Run were
  # byte-for-byte duplicates of the parent class, so they are now simply
  # inherited from AddIamPolicyBinding.
@base.ReleaseTracks(base.ReleaseTrack.ALPHA)
class AddIamPolicyBindingAlpha(base.Command):
  """Adds IAM policy binding to a model.

  Adds a policy binding to the IAM policy of a ML engine model, given a model ID
  and the binding. One binding consists of a member, a role, and an optional
  condition.
  """

  # Alpha renders the conditional variant of the detailed help because it
  # accepts an optional IAM condition (add_condition=True below).
  detailed_help = iam_util.GetDetailedHelpForAddIamPolicyBinding(
      'model', 'my_model', role='roles/ml.admin', condition=True)

  @staticmethod
  def Args(parser):
    """Registers flags, including the optional IAM condition (alpha only)."""
    _AddIamPolicyBindingFlags(parser, add_condition=True)

  def Run(self, args):
    region = region_util.GetRegion(args)
    with endpoint_util.MlEndpointOverrides(region=region):
      # IAM conditions are mutually exclusive with primitive roles
      # (owner/editor/viewer); validate before issuing the request.
      condition = iam_util.ValidateAndExtractCondition(args)
      iam_util.ValidateMutexConditionAndPrimitiveRoles(condition, args.role)
      return models_util.AddIamPolicyBindingWithCondition(
          models.ModelsClient(),
          args.model,
          args.member,
          args.role,
          condition)

View File

@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform models create command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import models
from googlecloudsdk.calliope import arg_parsers
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import constants
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import models_util
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import log
# Help text for --region on `models create`; the flag is mutually exclusive
# with --regions (see _AddCreateArgs).
_REGION_FLAG_HELPTEXT = """\
Google Cloud region of the regional endpoint to use for this command.
If you specify this flag, do not specify `--regions`.
If you specify `--region=global`, the model will be deployed to 'us-central1'
by default using the global endpoint. Please use `--regions` only if you want
to change the region where the model will be deployed against the global
endpoint.
If both flags are unspecified and you don't set ``ai_platform/region'', you will
be prompted for region of the regional endpoint.
Learn more about regional endpoints and see a list of available regions:
https://cloud.google.com/ai-platform/prediction/docs/regional-endpoints
"""
def _AddCreateArgs(parser,
                   support_console_logging=False):
  """Get arguments for the `ai-platform models create` command.

  Args:
    parser: the argparse parser to register flags on.
    support_console_logging: bool, whether to expose --enable-console-logging
      (only the beta/alpha tracks opt in).
  """
  flags.GetModelName().AddToParser(parser)
  flags.GetDescriptionFlag('model').AddToParser(parser)
  # --region (regional endpoint) and --regions (deployment region on the
  # global endpoint) are mutually exclusive.
  region_group = parser.add_mutually_exclusive_group()
  region_group.add_argument(
      '--region',
      choices=constants.SUPPORTED_REGIONS_WITH_GLOBAL,
      help=_REGION_FLAG_HELPTEXT)
  region_group.add_argument(
      '--regions',
      metavar='REGION',
      type=arg_parsers.ArgList(min_length=1),
      help="""\
The Google Cloud region where the model will be deployed (currently only a
single region is supported) against the global endpoint.
If you specify this flag, do not specify `--region`.
Defaults to 'us-central1' while using the global endpoint.
""")
  parser.add_argument(
      '--enable-logging',
      action='store_true',
      help='If set, enables StackDriver Logging for online prediction. These '
           'logs are like standard server access logs, containing information '
           'such as timestamps and latency for each request.')
  if support_console_logging:
    parser.add_argument(
        '--enable-console-logging',
        action='store_true',
        help='If set, enables StackDriver Logging of stderr and stdout streams '
             'for online prediction. These logs are more verbose than the '
             'standard access logs and can be helpful for debugging.')
  labels_util.AddCreateLabelsFlags(parser)
@base.ReleaseTracks(base.ReleaseTrack.GA)
class Create(base.CreateCommand):
  """Create a new AI Platform model."""

  @staticmethod
  def Args(parser):
    _AddCreateArgs(parser)

  def _Run(self, args, support_console_logging=False):
    """Shared implementation used by all release tracks of this command."""
    region, model_regions = models_util.GetModelRegion(args)
    with endpoint_util.MlEndpointOverrides(region=region):
      client = models.ModelsClient()
      labels = models_util.ParseCreateLabels(client, args)
      # --enable-console-logging only exists on tracks that opted in via
      # support_console_logging, so guard the attribute access.
      enable_console_logging = (
          support_console_logging and args.enable_console_logging)
      model = models_util.Create(
          client,
          args.model,
          model_regions,
          enable_logging=args.enable_logging,
          enable_console_logging=enable_console_logging,
          labels=labels,
          description=args.description)
      log.CreatedResource(model.name, kind='ai platform model')

  def Run(self, args):
    self._Run(args)
@base.ReleaseTracks(base.ReleaseTrack.BETA, base.ReleaseTrack.ALPHA)
class CreateBeta(Create):
  """Create a new AI Platform model."""

  @staticmethod
  def Args(parser):
    # Beta/alpha additionally expose --enable-console-logging.
    _AddCreateArgs(parser, support_console_logging=True)

  def Run(self, args):
    self._Run(args, support_console_logging=True)

View File

@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform models delete command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import models
from googlecloudsdk.api_lib.ml_engine import operations
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import models_util
from googlecloudsdk.command_lib.ml_engine import region_util
def _AddDeleteArgs(parser):
  """Register flags for `ai-platform models delete`."""
  model_arg = flags.GetModelName()
  model_arg.AddToParser(parser)
  region_arg = flags.GetRegionArg(include_global=True)
  region_arg.AddToParser(parser)
def _Run(args):
  """Delete the model, waiting on the resulting long-running operation."""
  region = region_util.GetRegion(args)
  with endpoint_util.MlEndpointOverrides(region=region):
    models_client = models.ModelsClient()
    operations_client = operations.OperationsClient()
    return models_util.Delete(models_client, operations_client, args.model)
@base.ReleaseTracks(base.ReleaseTrack.GA)
class Delete(base.DeleteCommand):
  r"""Delete an existing AI Platform model.

  ## EXAMPLES

  To delete all models matching the regular expression `vision[0-9]+`, run:

      $ {parent_command} list --uri \
          --filter 'name ~ vision[0-9]+' |
          xargs -n 1 {command}
  """

  @staticmethod
  def Args(parser):
    _AddDeleteArgs(parser)

  def Run(self, args):
    # Delegates to the shared module-level helper.
    return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.BETA, base.ReleaseTrack.ALPHA)
class DeleteBeta(Delete):
  r"""Delete an existing AI Platform model.

  ## EXAMPLES

  To delete all models matching the regular expression `vision[0-9]+`, run:

      $ {parent_command} list --uri \
          --filter 'name ~ vision[0-9]+' |
          xargs -n 1 {command}
  """
  # Args and Run were byte-for-byte duplicates of the GA Delete command they
  # already inherit from, so the redundant overrides have been removed.

View File

@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform models describe command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import models
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import region_util
# Resource collection for models. Not referenced in the visible code; kept
# for consistency with sibling command modules.
_COLLECTION = 'ml.models'


def _AddDescribeArgs(parser):
  """Register flags for `ai-platform models describe`."""
  model_arg = flags.GetModelName()
  model_arg.AddToParser(parser)
  region_arg = flags.GetRegionArg(include_global=True)
  region_arg.AddToParser(parser)
def _Run(args):
  """Look up the model via the appropriate (regional or global) endpoint."""
  region = region_util.GetRegion(args)
  with endpoint_util.MlEndpointOverrides(region=region):
    client = models.ModelsClient()
    return client.Get(args.model)
# TODO(b/62998601): don't repeat the first sentence in brief and description.
# Also, if b/62998171 is resolved this should be obsolete.
@base.ReleaseTracks(base.ReleaseTrack.GA)
class Describe(base.DescribeCommand):
  """Describe an existing AI Platform model.

  Describe an existing AI Platform model.

  If you would like to see all versions of a model, use
  `gcloud ai-platform versions list`.
  """

  @staticmethod
  def Args(parser):
    _AddDescribeArgs(parser)

  def Run(self, args):
    # Delegates to the shared module-level helper.
    return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.BETA, base.ReleaseTrack.ALPHA)
class DescribeBeta(base.DescribeCommand):
  """Describe an existing AI Platform model.

  Describe an existing AI Platform model.

  If you would like to see all versions of a model, use
  `gcloud ai-platform versions list`.
  """

  @staticmethod
  def Args(parser):
    _AddDescribeArgs(parser)

  def Run(self, args):
    # Delegates to the shared module-level helper.
    return _Run(args)

View File

@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fetch the IAM policy for a model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import models
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import models_util
from googlecloudsdk.command_lib.ml_engine import region_util
def _AddGetIamPolicyArgs(parser):
  """Register flags for the get-iam-policy commands.

  Args:
    parser: the argparse parser to register flags on.
  """
  # The verb previously read 'to set IAM policy for' — a copy/paste from the
  # set-iam-policy command. This is the *get* surface, so say so.
  flags.GetModelResourceArg(
      positional=True, required=True,
      verb='to get IAM policy for').AddToParser(parser)
  flags.GetRegionArg(include_global=True).AddToParser(parser)
  # The command subclasses ListCommand only for --filter/--format support;
  # URIs make no sense for a policy, so drop the inherited --uri flag.
  base.URI_FLAG.RemoveFromParser(parser)
def _Run(args):
  """Fetch the IAM policy of the given model."""
  region = region_util.GetRegion(args)
  with endpoint_util.MlEndpointOverrides(region=region):
    client = models.ModelsClient()
    return models_util.GetIamPolicy(client, args.model)
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class GetIamPolicyBeta(base.ListCommand):
  """Get the IAM policy for a model.

  Gets the IAM policy for the given model.

  Returns an empty policy if the resource does not have a policy set.

  ## EXAMPLES

  The following command gets the IAM policy for the model `my_model`:

      $ {command} my_model
  """

  @staticmethod
  def Args(parser):
    _AddGetIamPolicyArgs(parser)

  def Run(self, args):
    # Delegates to the shared module-level helper.
    return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.GA)
class GetIamPolicy(base.ListCommand):
  """Get the IAM policy for a model.

  Gets the IAM policy for the given model.

  Returns an empty policy if the resource does not have a policy set.

  ## EXAMPLES

  The following command gets the IAM policy for the model `my_model`:

      $ {command} my_model
  """

  @staticmethod
  def Args(parser):
    _AddGetIamPolicyArgs(parser)

  def Run(self, args):
    # Delegates to the shared module-level helper.
    return _Run(args)

View File

@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform models list command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import models
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import models_util
from googlecloudsdk.command_lib.ml_engine import region_util
from googlecloudsdk.core import resources
_COLLECTION = 'ml.models'
_DEFAULT_FORMAT = """
table(
name.basename(),
defaultVersion.name.basename()
)
"""
def _GetUri(model):
ref = resources.REGISTRY.ParseRelativeName(
model.name, models_util.MODELS_COLLECTION)
return ref.SelfLink()
def _AddListArgs(parser):
parser.display_info.AddFormat(_DEFAULT_FORMAT)
parser.display_info.AddUriFunc(_GetUri)
flags.GetRegionArg(include_global=True).AddToParser(parser)
def _Run(args):
region = region_util.GetRegion(args)
with endpoint_util.MlEndpointOverrides(region=region):
return models_util.List(models.ModelsClient())
@base.ReleaseTracks(base.ReleaseTrack.GA)
class List(base.ListCommand):
"""List existing AI Platform models."""
@staticmethod
def Args(parser):
_AddListArgs(parser)
def Run(self, args):
return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.BETA, base.ReleaseTrack.ALPHA)
class ListBeta(base.ListCommand):
"""List existing AI Platform models."""
@staticmethod
def Args(parser):
_AddListArgs(parser)
def Run(self, args):
return _Run(args)

View File

@@ -0,0 +1,160 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Remove IAM Policy Binding."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import models
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.iam import iam_util
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import models_util
from googlecloudsdk.command_lib.ml_engine import region_util
def _GetRemoveIamPolicyBindingArgs(parser, add_condition=False):
  """Register flags for the remove-iam-policy-binding commands.

  Args:
    parser: the argparse parser to register flags on.
    add_condition: bool, whether to accept an optional IAM condition.
  """
  iam_util.AddArgsForRemoveIamPolicyBinding(parser, add_condition=add_condition)
  # The verb previously read 'for which to remove IAM policy binding from',
  # which rendered ungrammatically in the generated help text.
  flags.GetModelResourceArg(
      required=True,
      verb='to remove IAM policy binding from').AddToParser(parser)
  flags.GetRegionArg(include_global=True).AddToParser(parser)
  # The command returns a policy, not listable resources, so the inherited
  # --uri flag is meaningless; remove it.
  base.URI_FLAG.RemoveFromParser(parser)
def _Run(args):
  """Remove the member/role binding from the model's IAM policy."""
  region = region_util.GetRegion(args)
  with endpoint_util.MlEndpointOverrides(region=region):
    models_client = models.ModelsClient()
    return models_util.RemoveIamPolicyBinding(
        models_client, args.model, args.member, args.role)
@base.ReleaseTracks(base.ReleaseTrack.GA)
class RemoveIamPolicyBinding(base.Command):
  """Removes IAM policy binding from an AI Platform Model resource.

  Removes a policy binding from an AI Platform Model. One
  binding consists of a member, a role and an optional condition.

  See $ {parent_command} get-iam-policy for examples of how to
  specify a model resource.
  """

  description = 'remove IAM policy binding from an AI Platform model'
  detailed_help = iam_util.GetDetailedHelpForRemoveIamPolicyBinding(
      'model', 'my_model', role='roles/ml.admin', condition=False)

  @staticmethod
  def Args(parser):
    """Register flags for this command.

    Args:
      parser: An argparse.ArgumentParser-like object. It is mocked out in order
        to capture some information, but behaves like an ArgumentParser.
    """
    _GetRemoveIamPolicyBindingArgs(parser, add_condition=False)

  def Run(self, args):
    """This is what gets called when the user runs this command.

    Args:
      args: an argparse namespace. All the arguments that were provided to this
        command invocation.

    Returns:
      The model's updated IAM policy.
    """
    return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.BETA)
class RemoveIamPolicyBindingBeta(RemoveIamPolicyBinding):
  """Removes IAM policy binding from an AI Platform Model resource.

  Removes a policy binding from an AI Platform Model. One
  binding consists of a member, a role and an optional condition.

  See $ {parent_command} get-iam-policy for examples of how to
  specify a model resource.
  """
  # The beta command was a byte-for-byte duplicate of the GA command
  # (description, detailed_help, Args, Run); all of it is now inherited
  # from RemoveIamPolicyBinding.
@base.ReleaseTracks(base.ReleaseTrack.ALPHA)
class RemoveIamPolicyBindingAlpha(base.Command):
  r"""Removes IAM policy binding from an AI Platform Model resource.

  Remove an IAM policy binding from the IAM policy of a ML model. One binding
  consists of a member, a role, and an optional condition.

  See $ {parent_command} get-iam-policy for examples of how to
  specify a model resource.
  """

  description = 'remove IAM policy binding from an AI Platform model'
  # This track accepts an optional IAM condition (add_condition=True in Args),
  # so render the conditional variant of the detailed help. It was previously
  # built with condition=False, inconsistent with both Args below and the
  # parallel AddIamPolicyBindingAlpha command.
  detailed_help = iam_util.GetDetailedHelpForRemoveIamPolicyBinding(
      'model', 'my_model', role='roles/ml.admin', condition=True)

  @staticmethod
  def Args(parser):
    """Register flags for this command.

    Args:
      parser: An argparse.ArgumentParser-like object. It is mocked out in order
        to capture some information, but behaves like an ArgumentParser.
    """
    _GetRemoveIamPolicyBindingArgs(parser, add_condition=True)

  def Run(self, args):
    """Removes the (optionally conditional) binding from the model's policy.

    Args:
      args: an argparse namespace. All the arguments that were provided to this
        command invocation.

    Returns:
      The model's updated IAM policy.
    """
    region = region_util.GetRegion(args)
    with endpoint_util.MlEndpointOverrides(region=region):
      # Conditions are mutually exclusive with primitive roles; validate
      # before issuing the request.
      condition = iam_util.ValidateAndExtractCondition(args)
      iam_util.ValidateMutexConditionAndPrimitiveRoles(condition, args.role)
      return models_util.RemoveIamPolicyBindingWithCondition(
          models.ModelsClient(), args.model, args.member, args.role, condition)

View File

@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Set the IAM policy for a model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import models
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.iam import iam_util
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import models_util
from googlecloudsdk.command_lib.ml_engine import region_util
def _AddSetIamPolicyArgs(parser):
  """Registers the flags for the `models set-iam-policy` command.

  Args:
    parser: The argparse parser to register the flags on.
  """
  model_arg = flags.GetModelResourceArg(
      positional=True, required=True, verb='to set IAM policy for')
  model_arg.AddToParser(parser)
  flags.GetRegionArg(include_global=True).AddToParser(parser)
  iam_util.AddArgForPolicyFile(parser)
def _Run(args):
  """Sets the IAM policy on the model named in args from args.policy_file."""
  with endpoint_util.MlEndpointOverrides(region=region_util.GetRegion(args)):
    client = models.ModelsClient()
    return models_util.SetIamPolicy(client, args.model, args.policy_file)
@base.ReleaseTracks(base.ReleaseTrack.GA)
class SetIamPolicyGA(base.Command):
  """Set the IAM policy for a model.

  Sets the IAM policy for the given model as defined in a JSON or YAML file.

  See https://cloud.google.com/iam/docs/managing-policies for details of
  the policy file format and contents.

  ## EXAMPLES

  The following command will read an IAM policy defined in a JSON file
  'policy.json' and set it for the model `my_model`:

    $ {command} my_model policy.json
  """

  @staticmethod
  def Args(parser):
    """Registers this command's flags on the given parser."""
    _AddSetIamPolicyArgs(parser)

  def Run(self, args):
    """Sets the policy; returns the result of models_util.SetIamPolicy."""
    return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class SetIamPolicyBeta(base.Command):
  """Set the IAM policy for a model.

  Sets the IAM policy for the given model as defined in a JSON or YAML file.

  See https://cloud.google.com/iam/docs/managing-policies for details of
  the policy file format and contents.

  ## EXAMPLES

  The following command will read an IAM policy defined in a JSON file
  'policy.json' and set it for the model `my_model`:

    $ {command} my_model policy.json
  """

  @staticmethod
  def Args(parser):
    """Registers this command's flags on the given parser."""
    _AddSetIamPolicyArgs(parser)

  def Run(self, args):
    """Sets the policy; returns the result of models_util.SetIamPolicy."""
    return _Run(args)

View File

@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform models update command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import models
from googlecloudsdk.api_lib.ml_engine import operations
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import models_util
from googlecloudsdk.command_lib.ml_engine import region_util
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import log
def _AddUpdateArgs(parser):
  """Get arguments for the `ai-platform models update` command."""
  for flag in (flags.GetModelName(),
               flags.GetRegionArg(include_global=True),
               flags.GetDescriptionFlag('model')):
    flag.AddToParser(parser)
  labels_util.AddUpdateLabelsFlags(parser)
def _Run(args):
  """Applies the requested update to the model, then logs the result."""
  with endpoint_util.MlEndpointOverrides(region=region_util.GetRegion(args)):
    models_util.Update(models.ModelsClient(), operations.OperationsClient(),
                       args)
    log.UpdatedResource(args.model, kind='ai platform model')
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class UpdateBeta(base.UpdateCommand):
  """Update an existing AI Platform model."""

  @staticmethod
  def Args(parser):
    """Registers this command's flags on the given parser."""
    _AddUpdateArgs(parser)

  def Run(self, args):
    """Updates the model; returns nothing (a status line is logged)."""
    _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.GA)
class Update(base.UpdateCommand):
  """Update an existing AI Platform model."""

  @staticmethod
  def Args(parser):
    """Registers this command's flags on the given parser."""
    _AddUpdateArgs(parser)

  def Run(self, args):
    """Updates the model; returns nothing (a status line is logged)."""
    _Run(args)

View File

@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command group for ai-platform operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
class MlEngineOperations(base.Group):
  """Manage AI Platform operations.

  Commands in this group let you cancel, describe, list, and wait on
  long-running AI Platform operations.
  """

View File

@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform operations cancel command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import operations
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import operations_util
def _AddCancelArgs(parser):
  """Registers the operation name and region flags."""
  for flag in (flags.OPERATION_NAME, flags.GetRegionArg()):
    flag.AddToParser(parser)
def _Run(args):
  """Cancels the named operation, honoring any regional endpoint override."""
  with endpoint_util.MlEndpointOverrides(region=args.region):
    return operations_util.Cancel(operations.OperationsClient(),
                                  args.operation)
@base.ReleaseTracks(base.ReleaseTrack.GA)
class Cancel(base.SilentCommand):
  """Cancel an AI Platform operation."""

  @staticmethod
  def Args(parser):
    """Registers the operation name and region flags."""
    _AddCancelArgs(parser)

  def Run(self, args):
    """Cancels the operation named in args."""
    return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class CancelBeta(base.SilentCommand):
  """Cancel an AI Platform operation."""

  @staticmethod
  def Args(parser):
    """Registers the operation name and region flags."""
    _AddCancelArgs(parser)

  def Run(self, args):
    """Cancels the operation named in args."""
    return _Run(args)

View File

@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform jobs describe command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import operations
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import operations_util
def _AddDescribeArgs(parser):
  """Registers the operation name and region flags."""
  for flag in (flags.OPERATION_NAME, flags.GetRegionArg()):
    flag.AddToParser(parser)
def _Run(args):
  """Fetches the named operation, honoring any regional endpoint override."""
  with endpoint_util.MlEndpointOverrides(region=args.region):
    return operations_util.Describe(operations.OperationsClient(),
                                    args.operation)
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class DescribeBeta(base.DescribeCommand):
  """Describe an AI Platform operation."""

  @staticmethod
  def Args(parser):
    """Registers the operation name and region flags."""
    _AddDescribeArgs(parser)

  def Run(self, args):
    """Returns the operation named in args."""
    return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.GA)
class Describe(base.DescribeCommand):
  """Describe an AI Platform operation."""

  @staticmethod
  def Args(parser):
    """Registers the operation name and region flags."""
    _AddDescribeArgs(parser)

  def Run(self, args):
    """Returns the operation named in args."""
    return _Run(args)

View File

@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform operations list command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import operations
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import operations_util
def _AddListArgs(parser):
  """Registers the region flag and the default table output format."""
  # One row per operation: short name, operation type, completion status.
  list_format = """\
      table(
          name.basename(),
          metadata.operationType,
          done
      )
  """
  parser.display_info.AddFormat(list_format)
  flags.GetRegionArg().AddToParser(parser)
def _Run(args):
  """Lists operations, honoring any regional endpoint override."""
  with endpoint_util.MlEndpointOverrides(region=args.region):
    return operations_util.List(operations.OperationsClient())
@base.ReleaseTracks(base.ReleaseTrack.GA)
class List(base.ListCommand):
  """List existing AI Platform operations."""

  @staticmethod
  def Args(parser):
    """Registers flags and the default output format for this command."""
    _AddListArgs(parser)

  def Run(self, args):
    """Lists the operations."""
    return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class ListBeta(base.ListCommand):
  """List existing AI Platform operations."""

  @staticmethod
  def Args(parser):
    """Registers flags and the default output format for this command."""
    _AddListArgs(parser)

  def Run(self, args):
    """Lists the operations."""
    return _Run(args)

View File

@@ -0,0 +1,75 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform operations wait command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import operations
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import operations_util
def _AddWaitArgs(parser):
  """Registers the operation name and region flags."""
  for flag in (flags.OPERATION_NAME, flags.GetRegionArg()):
    flag.AddToParser(parser)
def _Run(args):
  """Waits for the named operation, honoring any regional endpoint."""
  with endpoint_util.MlEndpointOverrides(region=args.region):
    return operations_util.Wait(operations.OperationsClient(), args.operation)
_DETAILED_HELP = {
'DESCRIPTION':
"""\
Wait for an AI Platform operation to complete.
Given an operation ID, this command polls the operation and blocks
until it completes. At completion, the operation message is printed
(which includes the operation response).
"""
}
@base.ReleaseTracks(base.ReleaseTrack.GA)
class Wait(base.CreateCommand):
  """Wait for an AI Platform operation to complete."""

  # NOTE(review): subclasses base.CreateCommand rather than a plain
  # base.Command -- presumably for its default output handling; confirm.
  detailed_help = _DETAILED_HELP

  @staticmethod
  def Args(parser):
    """Registers the operation name and region flags."""
    _AddWaitArgs(parser)

  def Run(self, args):
    """Blocks until the named operation completes."""
    return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class WaitBeta(base.CreateCommand):
  """Wait for an AI Platform operation to complete."""

  detailed_help = _DETAILED_HELP

  @staticmethod
  def Args(parser):
    """Registers the operation name and region flags."""
    _AddWaitArgs(parser)

  def Run(self, args):
    """Blocks until the named operation completes."""
    return _Run(args)

View File

@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform predict command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import predict
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import predict_utilities
from googlecloudsdk.command_lib.ml_engine import region_util
from googlecloudsdk.core import log
INPUT_INSTANCES_LIMIT = 100
def _AddPredictArgs(parser):
  """Register flags for this command.

  Args:
    parser: The argparse parser to register the flags on.
  """
  parser.add_argument('--model', required=True, help='Name of the model.')
  parser.add_argument(
      '--version',
      help="""\
Model version to be used.

If unspecified, the default version of the model will be used. To list model
versions run

$ {parent_command} versions list
""")
  # Exactly one instance source must be provided.
  group = parser.add_mutually_exclusive_group(required=True)
  group.add_argument(
      '--json-request',
      help="""\
Path to a local file containing the body of JSON request.

An example of a JSON request:

    {
      "instances": [
        {"x": [1, 2], "y": [3, 4]},
        {"x": [-1, -2], "y": [-3, -4]}
      ]
    }

This flag accepts "-" for stdin.
""")
  group.add_argument(
      '--json-instances',
      help="""\
Path to a local file from which instances are read.
Instances are in JSON format; newline delimited.

An example of the JSON instances file:

    {"images": [0.0, ..., 0.1], "key": 3}
    {"images": [0.0, ..., 0.1], "key": 2}
    ...

This flag accepts "-" for stdin.
""")
  group.add_argument(
      '--text-instances',
      help="""\
Path to a local file from which instances are read.
Instances are in UTF-8 encoded text format; newline delimited.

An example of the text instances file:

    107,4.9,2.5,4.5,1.7
    100,5.7,2.8,4.1,1.3
    ...

This flag accepts "-" for stdin.
""")
  flags.GetRegionArg(include_global=True).AddToParser(parser)
  flags.SIGNATURE_NAME.AddToParser(parser)
def _Run(args):
  """This is what gets called when the user runs this command.

  Args:
    args: an argparse namespace. All the arguments that were provided to this
      command invocation.

  Returns:
    A json object that contains predictions.
  """
  # Instances come from exactly one of --json-request, --json-instances or
  # --text-instances; at most INPUT_INSTANCES_LIMIT instances are read.
  instances = predict_utilities.ReadInstancesFromArgs(
      args.json_request,
      args.json_instances,
      args.text_instances,
      limit=INPUT_INSTANCES_LIMIT)
  region = region_util.GetRegion(args)
  with endpoint_util.MlEndpointOverrides(region=region):
    model_or_version_ref = predict_utilities.ParseModelOrVersionRef(
        args.model, args.version)
    # Warn when no signature name was given but the runtime version supports
    # named signatures, since the model's default signature may not exist.
    if (args.signature_name is None and
        predict_utilities.CheckRuntimeVersion(args.model, args.version)):
      log.status.Print(
          'You are running on a runtime version >= 1.8. '
          'If the signature defined in the model is '
          'not serving_default then you must specify it via '
          '--signature-name flag, otherwise the command may fail.')
    results = predict.Predict(
        model_or_version_ref, instances, signature_name=args.signature_name)
  if not args.IsSpecified('format'):
    # default format is based on the response.
    args.format = predict_utilities.GetDefaultFormat(
        results.get('predictions'))
  return results
@base.ReleaseTracks(base.ReleaseTrack.GA)
class Predict(base.Command):
  """Run AI Platform online prediction.

  `{command}` sends a prediction request to AI Platform for the given
  instances. This command will read up to 100 instances, though the service
  itself will accept instances up to the payload limit size (currently,
  1.5MB). If you are predicting on more instances, you should use batch
  prediction via

    $ {parent_command} jobs submit prediction.
  """

  @staticmethod
  def Args(parser):
    """Register flags for this command."""
    _AddPredictArgs(parser)

  def Run(self, args):
    """Sends the prediction request; returns the prediction results."""
    return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class PredictBeta(base.Command):
  """Run AI Platform online prediction.

  `{command}` sends a prediction request to AI Platform for the given
  instances. This command will read up to 100 instances, though the service
  itself will accept instances up to the payload limit size (currently,
  1.5MB). If you are predicting on more instances, you should use batch
  prediction via

    $ {parent_command} jobs submit prediction.
  """

  @staticmethod
  def Args(parser):
    """Register flags for this command."""
    _AddPredictArgs(parser)

  def Run(self, args):
    """Sends the prediction request; returns the prediction results."""
    return _Run(args)

View File

@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command group for ai-platform versions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.calliope import base
class Versions(base.Group):
  """AI Platform Versions commands.

  A version is an implementation of a model, represented as a serialized
  TensorFlow graph with trained parameters.

  When you communicate with AI Platform services, you use the
  combination of the model, version, and current project to identify a
  specific model implementation that is deployed in the cloud.
  """

View File

@@ -0,0 +1,265 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform versions create command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import operations
from googlecloudsdk.api_lib.ml_engine import versions_api
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import region_util
from googlecloudsdk.command_lib.ml_engine import versions_util
from googlecloudsdk.command_lib.util.args import labels_util
DETAILED_HELP = {
'EXAMPLES':
"""\
To create an AI Platform version model with the version ID 'versionId'
and with the name 'model-name', run:
$ {command} versionId --model=model-name
""",
}
def _AddCreateArgs(parser):
  """Add common arguments for `versions create` command.

  Args:
    parser: The argparse parser to register the flags on.
  """
  flags.GetModelName(positional=False, required=True).AddToParser(parser)
  flags.GetDescriptionFlag('version').AddToParser(parser)
  flags.GetRegionArg(include_global=True).AddToParser(parser)
  flags.VERSION_NAME.AddToParser(parser)
  base.Argument(
      '--origin',
      help="""\
Location of ```model/``` "directory" (see
https://cloud.google.com/ai-platform/prediction/docs/deploying-models#upload-model).

This overrides `deploymentUri` in the `--config` file. If this flag is
not passed, `deploymentUri` *must* be specified in the file from
`--config`.

Can be a Cloud Storage (`gs://`) path or local file path (no
prefix). In the latter case the files will be uploaded to Cloud
Storage and a `--staging-bucket` argument is required.
""").AddToParser(parser)
  flags.RUNTIME_VERSION.AddToParser(parser)
  base.ASYNC_FLAG.AddToParser(parser)
  flags.STAGING_BUCKET.AddToParser(parser)
  base.Argument(
      '--config',
      help="""\
Path to a YAML configuration file containing configuration parameters
for the
[Version](https://cloud.google.com/ai-platform/prediction/docs/reference/rest/v1/projects.models.versions)
to create.

The file is in YAML format. Note that not all attributes of a version
are configurable; available attributes (with example values) are:

    description: A free-form description of the version.
    deploymentUri: gs://path/to/source
    runtimeVersion: '2.1'
    # Set only one of either manualScaling or autoScaling.
    manualScaling:
      nodes: 10  # The number of nodes to allocate for this model.
    autoScaling:
      minNodes: 0  # The minimum number of nodes to allocate for this model.
    labels:
      user-defined-key: user-defined-value

The name of the version must always be specified via the required
VERSION argument.

Only one of manualScaling or autoScaling can be specified. If both
are specified in same yaml file an error will be returned.

If an option is specified both in the configuration file and via
command-line arguments, the command-line arguments override the
configuration file.
"""
  ).AddToParser(parser)
  labels_util.AddCreateLabelsFlags(parser)
  flags.FRAMEWORK_MAPPER.choice_arg.AddToParser(parser)
  flags.AddPythonVersionFlag(parser, 'when creating the version')
  flags.AddMachineTypeFlagToParser(parser)
  flags.GetAcceleratorFlag().AddToParser(parser)
  flags.AddAutoScalingFlags(parser)
@base.ReleaseTracks(base.ReleaseTrack.GA)
class CreateGA(base.CreateCommand):
  """Create a new AI Platform version.

  Creates a new version of an AI Platform model.

  For more details on managing AI Platform models and versions see
  https://cloud.google.com/ai-platform/prediction/docs/managing-models-jobs
  """

  detailed_help = DETAILED_HELP

  @staticmethod
  def Args(parser):
    """Registers the flags shared by all release tracks."""
    _AddCreateArgs(parser)

  def Run(self, args):
    """Creates the version.

    Args:
      args: an argparse namespace with the flags registered in Args.

    Returns:
      The result of versions_util.Create (presumably the create operation
      when --async is used, otherwise the finished version -- confirm).
    """
    region = region_util.GetRegion(args)
    with endpoint_util.MlEndpointOverrides(region=region):
      client = versions_api.VersionsClient()
      labels = versions_util.ParseCreateLabels(client, args)
      framework = flags.FRAMEWORK_MAPPER.GetEnumForChoice(args.framework)
      accelerator = flags.ParseAcceleratorFlag(args.accelerator)
      # 'asyncronous' (sic) matches the parameter name of
      # versions_util.Create.
      return versions_util.Create(
          client,
          operations.OperationsClient(),
          args.version,
          model=args.model,
          origin=args.origin,
          staging_bucket=args.staging_bucket,
          runtime_version=args.runtime_version,
          config_file=args.config,
          asyncronous=args.async_,
          description=args.description,
          labels=labels,
          machine_type=args.machine_type,
          framework=framework,
          python_version=args.python_version,
          accelerator_config=accelerator,
          min_nodes=args.min_nodes,
          max_nodes=args.max_nodes,
          metrics=args.metric_targets,
          autoscaling_hidden=False)
@base.ReleaseTracks(base.ReleaseTrack.BETA)
class CreateBeta(CreateGA):
  """Create a new AI Platform version.

  Creates a new version of an AI Platform model.

  For more details on managing AI Platform models and versions see
  https://cloud.google.com/ai-platform/prediction/docs/managing-models-jobs
  """

  @staticmethod
  def Args(parser):
    """Registers the GA flags plus the beta-only flag groups."""
    _AddCreateArgs(parser)
    flags.SERVICE_ACCOUNT.AddToParser(parser)
    flags.AddUserCodeArgs(parser)
    flags.AddExplainabilityFlags(parser)
    flags.AddContainerFlags(parser)

  def Run(self, args):
    """Creates the version, forwarding the beta-only flags.

    Args:
      args: an argparse namespace with the flags registered in Args.

    Returns:
      The result of versions_util.Create (presumably the create operation
      when --async is used, otherwise the finished version -- confirm).
    """
    region = region_util.GetRegion(args)
    with endpoint_util.MlEndpointOverrides(region=region):
      client = versions_api.VersionsClient()
      labels = versions_util.ParseCreateLabels(client, args)
      framework = flags.FRAMEWORK_MAPPER.GetEnumForChoice(args.framework)
      accelerator = flags.ParseAcceleratorFlag(args.accelerator)
      # 'asyncronous' (sic) matches the parameter name of
      # versions_util.Create.
      return versions_util.Create(
          client,
          operations.OperationsClient(),
          args.version,
          model=args.model,
          origin=args.origin,
          staging_bucket=args.staging_bucket,
          runtime_version=args.runtime_version,
          config_file=args.config,
          asyncronous=args.async_,
          description=args.description,
          labels=labels,
          machine_type=args.machine_type,
          framework=framework,
          python_version=args.python_version,
          service_account=args.service_account,
          prediction_class=args.prediction_class,
          package_uris=args.package_uris,
          accelerator_config=accelerator,
          explanation_method=args.explanation_method,
          num_integral_steps=args.num_integral_steps,
          num_paths=args.num_paths,
          image=args.image,
          command=args.command,
          container_args=args.args,
          env_vars=args.env_vars,
          ports=args.ports,
          predict_route=args.predict_route,
          health_route=args.health_route,
          min_nodes=args.min_nodes,
          max_nodes=args.max_nodes,
          metrics=args.metric_targets,
          containers_hidden=False,
          autoscaling_hidden=False)
@base.ReleaseTracks(base.ReleaseTrack.ALPHA)
class CreateAlpha(CreateBeta):
  """Create a new AI Platform version.

  Creates a new version of an AI Platform model.

  For more details on managing AI Platform models and versions see
  https://cloud.google.com/ai-platform/prediction/docs/managing-models-jobs
  """

  @staticmethod
  def Args(parser):
    """Registers the same flags as the beta track."""
    CreateBeta.Args(parser)

  def Run(self, args):
    """Creates the version, forwarding the alpha/beta flags.

    Args:
      args: an argparse namespace with the flags registered in Args.

    Returns:
      The result of versions_util.Create (presumably the create operation
      when --async is used, otherwise the finished version -- confirm).
    """
    region = region_util.GetRegion(args)
    with endpoint_util.MlEndpointOverrides(region=region):
      client = versions_api.VersionsClient()
      labels = versions_util.ParseCreateLabels(client, args)
      framework = flags.FRAMEWORK_MAPPER.GetEnumForChoice(args.framework)
      accelerator = flags.ParseAcceleratorFlag(args.accelerator)
      # 'asyncronous' (sic) matches the parameter name of
      # versions_util.Create.
      return versions_util.Create(
          client,
          operations.OperationsClient(),
          args.version,
          model=args.model,
          origin=args.origin,
          staging_bucket=args.staging_bucket,
          runtime_version=args.runtime_version,
          config_file=args.config,
          asyncronous=args.async_,
          labels=labels,
          description=args.description,
          machine_type=args.machine_type,
          framework=framework,
          python_version=args.python_version,
          prediction_class=args.prediction_class,
          package_uris=args.package_uris,
          service_account=args.service_account,
          accelerator_config=accelerator,
          explanation_method=args.explanation_method,
          num_integral_steps=args.num_integral_steps,
          num_paths=args.num_paths,
          image=args.image,
          command=args.command,
          container_args=args.args,
          env_vars=args.env_vars,
          ports=args.ports,
          predict_route=args.predict_route,
          health_route=args.health_route,
          min_nodes=args.min_nodes,
          max_nodes=args.max_nodes,
          metrics=args.metric_targets,
          containers_hidden=False,
          autoscaling_hidden=False)

View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform versions delete command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import operations
from googlecloudsdk.api_lib.ml_engine import versions_api
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import region_util
from googlecloudsdk.command_lib.ml_engine import versions_util
def _AddDeleteArgs(parser):
  """Registers the model, region, and version flags for `versions delete`."""
  for flag in (flags.GetModelName(positional=False, required=True),
               flags.GetRegionArg(include_global=True),
               flags.VERSION_NAME):
    flag.AddToParser(parser)
def _Run(args):
  """Deletes the given model version, honoring regional endpoints."""
  with endpoint_util.MlEndpointOverrides(region=region_util.GetRegion(args)):
    return versions_util.Delete(versions_api.VersionsClient(),
                                operations.OperationsClient(),
                                args.version, model=args.model)
@base.ReleaseTracks(base.ReleaseTrack.GA)
class Delete(base.DeleteCommand):
  """Delete an existing AI Platform version."""

  @staticmethod
  def Args(parser):
    """Registers the model, region, and version flags."""
    _AddDeleteArgs(parser)

  def Run(self, args):
    """Deletes the version named in args."""
    return _Run(args)
@base.ReleaseTracks(base.ReleaseTrack.BETA, base.ReleaseTrack.ALPHA)
class DeleteBeta(base.DeleteCommand):
  """Delete an existing AI Platform version."""

  @staticmethod
  def Args(parser):
    """Registers the model, region, and version flags."""
    _AddDeleteArgs(parser)

  def Run(self, args):
    """Deletes the version named in args."""
    return _Run(args)

View File

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform versions describe command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import versions_api
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import region_util
from googlecloudsdk.command_lib.ml_engine import versions_util
def _AddDescribeArgs(parser):
  """Registers the model, region, and version flags for `versions describe`."""
  for flag in (flags.GetModelName(positional=False, required=True),
               flags.GetRegionArg(include_global=True),
               flags.VERSION_NAME):
    flag.AddToParser(parser)
def _Run(args):
  """Returns the metadata of the requested model version."""
  with endpoint_util.MlEndpointOverrides(region=region_util.GetRegion(args)):
    return versions_util.Describe(versions_api.VersionsClient(), args.version,
                                  model=args.model)
@base.ReleaseTracks(base.ReleaseTrack.GA)
class Describe(base.DescribeCommand):
  """Describe an existing AI Platform version."""

  @staticmethod
  def Args(parser):
    """Registers the model, region, and version flags."""
    _AddDescribeArgs(parser)

  def Run(self, args):
    """Returns the version named in args."""
    return _Run(args)
# BETA/ALPHA release tracks; behavior matches the GA Describe command.
@base.ReleaseTracks(base.ReleaseTrack.BETA, base.ReleaseTrack.ALPHA)
class DescribeBeta(base.DescribeCommand):
  """Describe an existing AI Platform version."""

  @staticmethod
  def Args(parser):
    # Shared flag registration (model, region, version name).
    _AddDescribeArgs(parser)

  def Run(self, args):
    # Delegate to the module-level implementation shared across tracks.
    return _Run(args)

View File

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform versions list command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import versions_api
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import region_util
from googlecloudsdk.command_lib.ml_engine import versions_util
def _AddListArgs(parser):
  """Register the flags and default output format for `versions list`."""
  for list_flag in (flags.GetModelName(positional=False, required=True),
                    flags.GetRegionArg(include_global=True)):
    list_flag.AddToParser(parser)
  # Default table output: short version name, deployment URI, and state.
  parser.display_info.AddFormat(
      'table(name.basename(), deploymentUri, state)')
def _Run(args):
  """Return the versions of the given model from the resolved endpoint."""
  target_region = region_util.GetRegion(args)
  # All API traffic must go through the regional (or global) ML endpoint.
  with endpoint_util.MlEndpointOverrides(region=target_region):
    return versions_util.List(versions_api.VersionsClient(), model=args.model)
# GA release track of `versions list`.
@base.ReleaseTracks(base.ReleaseTrack.GA)
class List(base.ListCommand):
  """List existing AI Platform versions."""

  @staticmethod
  def Args(parser):
    # Shared flag registration and output format.
    _AddListArgs(parser)

  def Run(self, args):
    # Delegate to the module-level implementation shared across tracks.
    return _Run(args)
# ALPHA/BETA release tracks; behavior matches the GA List command.
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class ListBeta(base.ListCommand):
  """List existing AI Platform versions."""

  @staticmethod
  def Args(parser):
    # Shared flag registration and output format.
    _AddListArgs(parser)

  def Run(self, args):
    # Delegate to the module-level implementation shared across tracks.
    return _Run(args)

View File

@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform versions set-default command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import versions_api
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import region_util
from googlecloudsdk.command_lib.ml_engine import versions_util
def _AddSetDefaultArgs(parser):
  """Register the flags shared by every `versions set-default` track."""
  # --model is required as a keyword flag; the region flag also accepts the
  # global endpoint.
  for sd_flag in (flags.GetModelName(positional=False, required=True),
                  flags.GetRegionArg(include_global=True),
                  flags.VERSION_NAME):
    sd_flag.AddToParser(parser)
def _Run(args):
  """Mark args.version as its model's default and return the result."""
  target_region = region_util.GetRegion(args)
  # Route the request through the regional (or global) ML endpoint.
  with endpoint_util.MlEndpointOverrides(region=target_region):
    api_client = versions_api.VersionsClient()
    return versions_util.SetDefault(api_client, args.version, model=args.model)
_DETAILED_HELP = {
'DESCRIPTION':
"""\
Sets an existing AI Platform version as the default for its model.
*{command}* sets an existing AI Platform version as the default for its
model. Only one version may be the default for a given version.
"""
}
# GA release track of `versions set-default`.
@base.ReleaseTracks(base.ReleaseTrack.GA)
class SetDefault(base.DescribeCommand):
  """Sets an existing AI Platform version as the default for its model."""

  # Long-form help text rendered by `--help`.
  detailed_help = _DETAILED_HELP

  @staticmethod
  def Args(parser):
    # Shared flag registration (model, region, version name).
    _AddSetDefaultArgs(parser)

  def Run(self, args):
    # Delegate to the module-level implementation shared across tracks.
    return _Run(args)
# ALPHA/BETA release tracks of `versions set-default`.
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class SetDefaultBeta(SetDefault):
  """Sets an existing AI Platform version as the default for its model."""

  # `Args`, `Run`, and `detailed_help` are inherited unchanged from the GA
  # SetDefault command; the original re-defined byte-identical copies of
  # Args and Run here, which only duplicated code without changing behavior.

View File

@@ -0,0 +1,108 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform versions update command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import operations
from googlecloudsdk.api_lib.ml_engine import versions_api
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import endpoint_util
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import region_util
from googlecloudsdk.command_lib.ml_engine import versions_util
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import log
def _AddUpdateArgs(parser):
  """Get arguments for the `ai-platform versions update` command."""
  flags.AddVersionResourceArg(parser, 'to update')
  flags.GetDescriptionFlag('version').AddToParser(parser)
  flags.GetRegionArg(include_global=True).AddToParser(parser)
  labels_util.AddUpdateLabelsFlags(parser)
  # Fixed two typos in the help text below: "maxinum" -> "maximum", and
  # "requestLoggingconfig" -> "requestLoggingConfig" (the API field name,
  # matching the flags added by flags.AddRequestLoggingConfigFlags).
  base.Argument(
      '--config',
      metavar='YAML_FILE',
      help="""\
Path to a YAML configuration file containing configuration parameters
for the
[version](https://cloud.google.com/ml/reference/rest/v1/projects.models.versions)
to create.

The file is in YAML format. Note that not all attributes of a version
are configurable; available attributes (with example values) are:

    description: A free-form description of the version.
    manualScaling:
      nodes: 10  # The number of nodes to allocate for this model.
    autoScaling:
      minNodes: 0  # The minimum number of nodes to allocate for this model.
      maxNodes: 1  # The maximum number of nodes to allocate for this model.
    requestLoggingConfig:
      bigqueryTableName: someTable  # Fully qualified BigQuery table name.
      samplingPercentage: 0.5  # Percentage of requests to be logged.

The name of the version must always be specified via the required
VERSION argument.

Only one of manualScaling or autoScaling can be specified. If both
are specified in same yaml file, an error will be returned.

Labels cannot currently be set in the config.yaml; please use
the command-line flags to alter them.

If an option is specified both in the configuration file and via
command-line arguments, the command-line arguments override the
configuration file.
"""
  ).AddToParser(parser)
def _Run(args):
  """Apply the requested update to the version and log the outcome."""
  target_region = region_util.GetRegion(args)
  # All API calls, including operation polling, go through the regional
  # (or global) ML endpoint.
  with endpoint_util.MlEndpointOverrides(region=target_region):
    version_ref = args.CONCEPTS.version.Parse()
    versions_util.Update(versions_api.VersionsClient(),
                         operations.OperationsClient(), version_ref, args)
    log.UpdatedResource(args.version, kind='AI Platform version')
# GA release track of `versions update`.
@base.ReleaseTracks(base.ReleaseTrack.GA)
class Update(base.UpdateCommand):
  """Update an AI Platform version."""

  @staticmethod
  def Args(parser):
    # Shared flag registration (version resource, description, region,
    # labels, --config).
    _AddUpdateArgs(parser)

  def Run(self, args):
    # Delegate to the module-level implementation shared across tracks.
    return _Run(args)
# ALPHA/BETA release tracks of `versions update`. Adds request-logging
# configuration flags on top of the GA flag set.
@base.ReleaseTracks(base.ReleaseTrack.ALPHA, base.ReleaseTrack.BETA)
class UpdateBeta(base.UpdateCommand):
  """Update an AI Platform version."""

  @staticmethod
  def Args(parser):
    # Same flags as GA, plus the beta-only request-logging flags.
    _AddUpdateArgs(parser)
    flags.AddRequestLoggingConfigFlags(parser)

  def Run(self, args):
    # Delegate to the module-level implementation shared across tracks.
    return _Run(args)