feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants used for AI Platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
# Regions in which the AI Platform (ml.googleapis.com) service is available.
# NOTE(review): this list is hard-coded; it may lag behind actual service
# availability — confirm against the AI Platform regions documentation.
SUPPORTED_REGIONS = [
    'asia-east1',
    'asia-northeast1',
    'asia-southeast1',
    'australia-southeast1',
    'europe-west1',
    'europe-west2',
    'europe-west3',
    'europe-west4',
    'northamerica-northeast1',
    'us-central1',
    'us-east1',
    'us-east4',
    'us-west1',
]
# The same regions plus the special 'global' pseudo-region, for commands that
# also accept the global endpoint.
SUPPORTED_REGIONS_WITH_GLOBAL = ['global'] + SUPPORTED_REGIONS

View File

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for operating on different endpoints."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import contextlib
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from six.moves.urllib import parse
ML_API_VERSION = 'v1'
ML_API_NAME = 'ml'
def DeriveMLRegionalEndpoint(endpoint, region):
  """Returns `endpoint` with its network location prefixed by `region`.

  e.g. https://ml.googleapis.com/ -> https://us-central1-ml.googleapis.com/
  """
  parts = parse.urlparse(endpoint)
  regional_netloc = '{}-{}'.format(region, parts.netloc)
  return parse.urlunparse(parts._replace(netloc=regional_netloc))
@contextlib.contextmanager
def MlEndpointOverrides(region=None):
  """Context manager to override the AI Platform endpoints for a while.

  The override is only installed for a non-global region; in all cases the
  previous override value is restored on exit.

  Args:
    region: str, region of the AI Platform stack.

  Yields:
    None.
  """
  used_endpoint = GetEffectiveMlEndpoint(region)
  old_endpoint = properties.VALUES.api_endpoint_overrides.ml.Get()
  try:
    log.status.Print('Using endpoint [{}]'.format(used_endpoint))
    if region and region != 'global':
      properties.VALUES.api_endpoint_overrides.ml.Set(used_endpoint)
    yield
  finally:
    # Fix: the original rebound `old_endpoint` to the return value of Set(),
    # a meaningless assignment that obscured the intent. Only the restore
    # side effect matters here.
    properties.VALUES.api_endpoint_overrides.ml.Set(old_endpoint)
def GetEffectiveMlEndpoint(region):
  """Returns regional ML Endpoint, or global if region not set."""
  endpoint = apis.GetEffectiveApiEndpoint(ML_API_NAME, ML_API_VERSION)
  # 'global' (or no region at all) means the un-prefixed endpoint.
  if not region or region == 'global':
    return endpoint
  return DeriveMLRegionalEndpoint(endpoint, region)

View File

@@ -0,0 +1,569 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for job submission preparation.
The main entry point is UploadPythonPackages, which takes in parameters derived
from the command line arguments and returns a list of URLs to be given to the
AI Platform API. See its docstring for details.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import abc
import collections
import contextlib
import io
import os
import sys
import textwrap
from googlecloudsdk.api_lib.storage import storage_util
from googlecloudsdk.command_lib.ml_engine import uploads
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import execution_utils
from googlecloudsdk.core import log
from googlecloudsdk.core.util import files
import six
from six.moves import map
DEFAULT_SETUP_FILE = """\
from setuptools import setup, find_packages
if __name__ == '__main__':
setup(
name='{package_name}',
packages=find_packages(include=['{package_name}'])
)
"""
class UploadFailureError(exceptions.Error):
  """Generic error with the packaging/upload process.

  Base class for all errors raised while building or staging user packages.
  """
  pass
class SetuptoolsFailedError(UploadFailureError):
  """Error indicating that setuptools itself failed."""

  def __init__(self, output, generated):
    # output: captured stdout/stderr of the failed setup.py run.
    # generated: whether the setup.py was auto-generated by us; this changes
    # the remediation advice shown to the user.
    msg = ('Packaging of user Python code failed with message:\n\n'
           '{}\n\n').format(output)
    if generated:
      msg += ('Try manually writing a setup.py file at your package root and '
              'rerunning the command.')
    else:
      msg += ('Try manually building your Python code by running:\n'
              '    $ python setup.py sdist\n'
              'and providing the output via the `--packages` flag (for '
              'example, `--packages dist/package.tar.gz,dist/package2.whl)`')
    super(SetuptoolsFailedError, self).__init__(msg)
class SysExecutableMissingError(UploadFailureError):
  """Error indicating that sys.executable was empty."""

  def __init__(self):
    super(SysExecutableMissingError, self).__init__(
        textwrap.dedent("""\
No Python executable found on path. A Python executable with setuptools
installed on the PYTHONPATH is required for building AI Platform training jobs.
"""))
class MissingInitError(UploadFailureError):
  """Error indicating that the package to build had no __init__.py file."""

  def __init__(self, package_dir):
    # package_dir: str, the directory that was expected to be a package.
    super(MissingInitError, self).__init__(textwrap.dedent("""\
[{}] is not a valid Python package because it does not contain an \
`__init__.py` file. Please create one and try again. Also, please \
ensure that --package-path refers to a local directory.
""").format(package_dir))
class UncopyablePackageError(UploadFailureError):
  """Error with copying the package.

  Raised when the source tree cannot be copied into a writable location.
  """
class DuplicateEntriesError(UploadFailureError):
  """Error indicating that multiple files with the same name were provided."""

  def __init__(self, duplicates):
    # duplicates: list of str, the basenames that occurred more than once.
    message = ('Cannot upload multiple packages with the same filename: '
               '[{}]'.format(', '.join(duplicates)))
    super(DuplicateEntriesError, self).__init__(message)
class NoStagingLocationError(UploadFailureError):
  """No staging location was provided but one was required."""
class InvalidSourceDirError(UploadFailureError):
  """Error indicating that the source directory is invalid."""

  def __init__(self, source_dir):
    # source_dir: str, the path that failed validation.
    message = 'Source directory [{}] is not a valid directory.'.format(
        source_dir)
    super(InvalidSourceDirError, self).__init__(message)
def _CopyIfNotWritable(source_dir, temp_dir):
  """Returns a writable directory with the same contents as source_dir.

  If source_dir is already writable it is returned unchanged; otherwise its
  contents are copied into a 'dest' subdirectory of temp_dir.

  Args:
    source_dir: str, the directory to (potentially) copy.
    temp_dir: str, a writable temporary directory for any copied code.

  Returns:
    str, a writable directory with the same contents as source_dir.

  Raises:
    UncopyablePackageError: if the copy cannot be performed.
    InvalidSourceDirError: if the source directory is not valid.
  """
  if not os.path.isdir(source_dir):
    raise InvalidSourceDirError(source_dir)
  try:
    if files.HasWriteAccessInDir(source_dir):
      return source_dir
  except ValueError:
    # A race condition can invalidate the directory between the isdir check
    # above and the write-access probe.
    raise InvalidSourceDirError(source_dir)
  if files.IsDirAncestorOf(source_dir, temp_dir):
    # Copying a directory into itself would recurse forever.
    raise UncopyablePackageError(
        'Cannot copy directory since working directory [{}] is inside of '
        'source directory [{}].'.format(temp_dir, source_dir))
  dest_dir = os.path.join(temp_dir, 'dest')
  log.debug('Copying local source tree from [%s] to [%s]', source_dir, dest_dir)
  try:
    files.CopyTree(source_dir, dest_dir)
  except OSError:
    raise UncopyablePackageError(
        'Cannot write to working location [{}]'.format(dest_dir))
  return dest_dir
def _GenerateSetupPyIfNeeded(setup_py_path, package_name):
  """Writes a temporary setup.py unless one already exists at the given path.

  Args:
    setup_py_path: str, a path to the expected setup.py location.
    package_name: str, the Python package name used in the generated file.

  Returns:
    bool, whether the setup.py file was generated.
  """
  log.debug('Looking for setup.py file at [%s]', setup_py_path)
  if os.path.isfile(setup_py_path):
    log.info('Using existing setup.py file at [%s]', setup_py_path)
    return False
  contents = DEFAULT_SETUP_FILE.format(package_name=package_name)
  log.info('Generating temporary setup.py file:\n%s', contents)
  files.WriteFileContents(setup_py_path, contents)
  return True
@contextlib.contextmanager
def _TempDirOrBackup(default_dir):
  """Yields a temporary directory or a backup temporary directory.

  Prefers creating a temporary directory (which will be cleaned up when the
  context manager is closed), but falls back to default_dir. There are systems
  where users can't write to temp, but we still need to copy.

  Args:
    default_dir: str, the backup temporary directory.

  Yields:
    str, the temporary directory.
  """
  try:
    temp_dir = files.TemporaryDirectory()
    # We can't use the context manager form of files.TemporaryDirectory()
    # because it makes it hard to distinguish between an OSError that occurred
    # during the creation of the temporary directory and one that occurred in
    # the middle of *this* context manager's body.
    path = temp_dir.__enter__()
  except OSError:
    temp_dir = None
    # Some systems don't allow access to '/tmp'
    path = default_dir
  try:
    yield path
  finally:
    if temp_dir:
      # Forward any in-flight exception info so the TemporaryDirectory can
      # clean up appropriately, mirroring normal `with` semantics.
      temp_dir.__exit__(*sys.exc_info())
class _SetupPyCommand(six.with_metaclass(abc.ABCMeta, object)):
  """A command to run setup.py in a given environment.

  Includes the Python version to use and the arguments with which to run
  setup.py.

  Attributes:
    setup_py_path: str, the path to the setup.py file
    setup_py_args: list of str, the arguments with which to call setup.py
    package_root: str, path to the directory containing the package to build
      (must be writable, or setuptools will fail)
  """

  def __init__(self, setup_py_path, setup_py_args, package_root):
    self.setup_py_path = setup_py_path
    self.setup_py_args = setup_py_args
    self.package_root = package_root

  @abc.abstractmethod
  def GetArgs(self):
    """Returns arguments to use for execution (including Python executable)."""
    raise NotImplementedError()

  @abc.abstractmethod
  def GetEnv(self):
    """Returns the environment dictionary to use for Python execution."""
    raise NotImplementedError()

  def Execute(self, out):
    """Run the configured setup.py command.

    Args:
      out: a stream to which the command output should be written.

    Returns:
      int, the return code of the command.
    """
    # no_exit=True so a failure surfaces as a return code the caller can use
    # to try the next fallback command instead of exiting the process.
    return execution_utils.Exec(
        self.GetArgs(),
        no_exit=True, out_func=out.write, err_func=out.write,
        cwd=self.package_root, env=self.GetEnv())
class _CloudSdkPythonSetupPyCommand(_SetupPyCommand):
  """A command that uses the Cloud SDK Python environment.

  It uses the same OS environment, plus the same PYTHONPATH.
  This is preferred, since it's more controlled.
  """

  def GetArgs(self):
    """Builds argv via the Cloud SDK's Python tool wrapper."""
    return execution_utils.ArgsForPythonTool(self.setup_py_path,
                                             *self.setup_py_args,
                                             python=GetPythonExecutable())

  def GetEnv(self):
    """Copies the OS environment, propagating the SDK's sys.path."""
    env = dict(os.environ)
    env['PYTHONPATH'] = os.pathsep.join(sys.path)
    return env
class _SystemPythonSetupPyCommand(_SetupPyCommand):
  """A command that uses the system Python environment.

  Uses the same executable as the Cloud SDK.
  Important in case of e.g. a setup.py file that has non-stdlib dependencies.
  """

  def GetArgs(self):
    """Invokes setup.py directly with the Python executable."""
    argv = [GetPythonExecutable(), self.setup_py_path]
    argv.extend(self.setup_py_args)
    return argv

  def GetEnv(self):
    # None tells the executor to inherit the current process environment.
    return None
def GetPythonExecutable():
  """Returns the Python executable path, raising if none is available.

  Raises:
    SysExecutableMissingError: if no Python executable could be found.
  """
  try:
    return execution_utils.GetPythonExecutable()
  except ValueError:
    raise SysExecutableMissingError()
def _RunSetupTools(package_root, setup_py_path, output_dir):
  """Executes the setuptools `sdist` command.

  Specifically, runs `python setup.py sdist` (with the full path to `setup.py`
  given by setup_py_path) with arguments to put the final output in output_dir
  and all possible temporary files in a temporary directory. package_root is
  used as the working directory.

  May attempt to run setup.py multiple times with different
  environments/commands if any execution fails:

  1. Using the Cloud SDK Python environment, with a full setuptools invocation
     (`egg_info`, `build`, and `sdist`).
  2. Using the system Python environment, with a full setuptools invocation
     (`egg_info`, `build`, and `sdist`).
  3. Using the Cloud SDK Python environment, with an intermediate setuptools
     invocation (`build` and `sdist`).
  4. Using the system Python environment, with an intermediate setuptools
     invocation (`build` and `sdist`).
  5. Using the Cloud SDK Python environment, with a simple setuptools
     invocation which will also work for plain distutils-based setup.py (just
     `sdist`).
  6. Using the system Python environment, with a simple setuptools
     invocation which will also work for plain distutils-based setup.py (just
     `sdist`).

  The reason for this order is that it prefers first the setup.py invocations
  which leave the fewest files on disk. Then, we prefer the Cloud SDK execution
  environment as it will be the most stable.

  package_root must be writable, or setuptools will fail (there are
  temporary files from setuptools that get put in the CWD).

  Args:
    package_root: str, the directory containing the package (that is, the
      *parent* of the package itself).
    setup_py_path: str, the path to the `setup.py` file to execute.
    output_dir: str, path to a directory in which the built packages should be
      created.

  Returns:
    list of str, the full paths to the generated packages.

  Raises:
    SysExecutableMissingError: if sys.executable is None
    RuntimeError: if the execution of setuptools exited non-zero.
  """
  # Unfortunately, there doesn't seem to be any easy way to move *all*
  # temporary files out of the current directory, so we'll fail here if we
  # can't write to it.
  with _TempDirOrBackup(package_root) as working_dir:
    # Simpler, but more messy (leaves artifacts on disk) command. This will
    # work for both distutils- and setuptools-based setup.py files.
    sdist_args = ['sdist', '--dist-dir', output_dir]
    # The 'build' and 'egg_info' commands (which are invoked anyway as
    # subcommands of 'sdist') are included to ensure that the fewest possible
    # artifacts are left on disk.
    build_args = [
        'build', '--build-base', working_dir, '--build-temp', working_dir]
    # Some setuptools versions don't support directly running the egg_info
    # command
    egg_info_args = ['egg_info', '--egg-base', working_dir]
    setup_py_arg_sets = (
        egg_info_args + build_args + sdist_args,
        build_args + sdist_args,
        sdist_args)
    # See docstring for the reasoning behind this order.
    setup_py_commands = []
    for setup_py_args in setup_py_arg_sets:
      setup_py_commands.append(_CloudSdkPythonSetupPyCommand(
          setup_py_path, setup_py_args, package_root))
      setup_py_commands.append(_SystemPythonSetupPyCommand(
          setup_py_path, setup_py_args, package_root))
    for setup_py_command in setup_py_commands:
      out = io.StringIO()
      return_code = setup_py_command.Execute(out)
      if not return_code:
        break
    else:
      # for/else: no fallback command succeeded; surface the last output.
      raise RuntimeError(out.getvalue())
    local_paths = [os.path.join(output_dir, rel_file)
                   for rel_file in os.listdir(output_dir)]
    log.debug('Python packaging resulted in [%s]', ', '.join(local_paths))
    return local_paths
def BuildPackages(package_path, output_dir):
  """Builds Python packages from the given package source.

  That is, builds Python packages from the code in package_path, using its
  parent directory (the 'package root') as its context using the setuptools
  `sdist` command.

  If there is a `setup.py` file in the package root, use that. Otherwise,
  use a simple, temporary one made for this package.

  We try to be as unobtrusive as possible (see _RunSetupTools for details):

  - setuptools writes some files to the package root--we move as many temporary
    generated files out of the package root as possible
  - the final output gets written to output_dir
  - any temporary setup.py file is written outside of the package root.
  - if the current directory isn't writable, we silently make a temporary copy

  Args:
    package_path: str. Path to the package. This should be the path to
      the directory containing the Python code to be built, *not* its parent
      (which optionally contains setup.py and other metadata).
    output_dir: str, path to a long-lived directory in which the built packages
      should be created.

  Returns:
    list of str. The full local path to all built Python packages.

  Raises:
    SetuptoolsFailedError: If the setup.py file fails to successfully build.
    MissingInitError: If the package doesn't contain an `__init__.py` file.
    InvalidSourceDirError: if the source directory is not valid.
  """
  package_path = os.path.abspath(package_path)
  package_root = os.path.dirname(package_path)
  with _TempDirOrBackup(package_path) as working_dir:
    package_root = _CopyIfNotWritable(package_root, working_dir)
    if not os.path.exists(os.path.join(package_path, '__init__.py')):
      # We could drop `__init__.py` in here, but it's pretty likely that this
      # indicates an incorrect directory or some bigger problem and we don't
      # want to obscure that.
      #
      # Note that we could more strictly validate here by checking each package
      # in the `--module-name` argument, but this should catch most issues.
      raise MissingInitError(package_path)
    setup_py_path = os.path.join(package_root, 'setup.py')
    package_name = os.path.basename(package_path)
    generated = _GenerateSetupPyIfNeeded(setup_py_path, package_name)
    try:
      return _RunSetupTools(package_root, setup_py_path, output_dir)
    except RuntimeError as err:
      raise SetuptoolsFailedError(six.text_type(err), generated)
    finally:
      if generated:
        # For some reason, this artifact gets generated in the package root by
        # setuptools, even after setting PYTHONDONTWRITEBYTECODE or running
        # `python setup.py clean --all`. It's weird to leave someone a .pyc for
        # a file they never knew they had, so we clean it up.
        pyc_file = os.path.join(package_root, 'setup.pyc')
        for path in (setup_py_path, pyc_file):
          try:
            os.unlink(path)
          except OSError:
            # Bug fix: log the file that actually failed to be removed
            # (`path`); the original always logged `pyc_file`, even when
            # the failing unlink was for setup.py.
            log.debug(
                "Couldn't remove file [%s] (it may never have been created).",
                path)
def _UploadFilesByPath(paths, staging_location):
  """Uploads files after validating and transforming input type.

  Args:
    paths: list of str, local paths of the files to upload.
    staging_location: storage_util.ObjectReference, the Cloud Storage prefix
      under which the files are staged.

  Returns:
    The result of uploads.UploadFiles for the validated (path, name) pairs.

  Raises:
    NoStagingLocationError: if no staging location was given.
    DuplicateEntriesError: if two paths share the same basename.
  """
  if not staging_location:
    raise NoStagingLocationError()
  basenames = [os.path.basename(p) for p in paths]
  duplicates = [name for name, count in collections.Counter(basenames).items()
                if count > 1]
  if duplicates:
    raise DuplicateEntriesError(duplicates)
  upload_pairs = [(path, os.path.basename(path)) for path in paths]
  return uploads.UploadFiles(upload_pairs, staging_location.bucket_ref,
                             staging_location.name)
def UploadPythonPackages(packages=(), package_path=None, staging_location=None):
  """Uploads Python packages (if necessary), building them as-specified.

  An AI Platform job needs one or more Python packages to run. These Python
  packages can be specified in one of three ways:

    1. As a path to a local, pre-built Python package file.
    2. As a path to a Cloud Storage-hosted, pre-built Python package file
       (paths beginning with 'gs://').
    3. As a local Python source tree (the `--package-path` flag).

  In case 1, we upload the local files to Cloud Storage[1] and provide their
  paths. These can then be given to the AI Platform API, which can fetch
  these files.

  In case 2, we don't need to do anything. We can just send these paths
  directly to the AI Platform API.

  In case 3, we perform a build using setuptools[2], and upload the resulting
  artifacts to Cloud Storage[1]. The paths to these artifacts can be given to
  the AI Platform API. See the `BuildPackages` method.

  These methods of specifying Python packages may be combined.

  [1] Uploads are to a specially-prefixed location in a user-provided Cloud
  Storage staging bucket. If the user provides bucket `gs://my-bucket/`, a
  file `package.tar.gz` is uploaded to
  `gs://my-bucket/<job name>/<checksum>/package.tar.gz`.

  [2] setuptools must be installed on the local user system.

  Args:
    packages: list of str. Path to extra tar.gz packages to upload, if any. If
      empty, a package_path must be provided.
    package_path: str. Relative path to source directory to be built, if any.
      If omitted, one or more packages must be provided.
    staging_location: storage_util.ObjectReference. Cloud Storage prefix to
      which archives are uploaded. Not necessary if only remote packages are
      given.

  Returns:
    list of str. Fully qualified Cloud Storage URLs (`gs://..`) from uploaded
    packages.

  Raises:
    ValueError: If packages is empty, and building package_path produces no
      tar archives.
    SetuptoolsFailedError: If the setup.py file fails to successfully build.
    MissingInitError: If the package doesn't contain an `__init__.py` file.
    DuplicateEntriesError: If multiple files with the same name were provided.
    ArgumentError: if no packages were found in the given path, or no
      staging_location was given but uploads were required.
  """
  remote_paths = []
  local_paths = []
  # Pre-sort the given packages: 'gs://' paths need no upload.
  for package in packages:
    if storage_util.ObjectReference.IsStorageUrl(package):
      remote_paths.append(package)
    else:
      local_paths.append(package)
  if package_path:
    package_root = os.path.dirname(os.path.abspath(package_path))
    with _TempDirOrBackup(package_root) as working_dir:
      local_paths.extend(BuildPackages(package_path,
                                       os.path.join(working_dir, 'output')))
      # Upload while the temporary directory is still alive: the freshly
      # built artifacts live inside working_dir.
      remote_paths.extend(_UploadFilesByPath(local_paths, staging_location))
  elif local_paths:
    # Can't combine this with above because above requires the temporary
    # directory to still be around
    remote_paths.extend(_UploadFilesByPath(local_paths, staging_location))
  return remote_paths
def GetStagingLocation(job_id=None, staging_bucket=None, job_dir=None):
  """Get the appropriate staging location for the job given the arguments.

  Returns a storage_util.ObjectReference derived from the staging bucket (if
  given), else from the job directory's 'packages' subpath, else None.
  """
  if staging_bucket:
    return storage_util.ObjectReference.FromBucketRef(staging_bucket, job_id)
  if job_dir:
    # Drop empty segments so a bare bucket root yields just 'packages'.
    segments = [f for f in [job_dir.name.rstrip('/'), 'packages'] if f]
    return storage_util.ObjectReference.FromName(job_dir.bucket,
                                                 '/'.join(segments))
  return None

View File

@@ -0,0 +1,531 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ml-engine jobs command code."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import exceptions
from googlecloudsdk.api_lib.ml_engine import jobs
from googlecloudsdk.command_lib.logs import stream
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import jobs_prep
from googlecloudsdk.command_lib.ml_engine import log_utils
from googlecloudsdk.command_lib.util.apis import arg_utils
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import execution_utils
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml
from googlecloudsdk.core.resource import resource_printer
import six
# Cloud Console deep links printed after describe/submit.
_CONSOLE_URL = ('https://console.cloud.google.com/mlengine/jobs/{job_id}?'
                'project={project}')
_LOGS_URL = ('https://console.cloud.google.com/logs?'
             'resource=ml_job%2Fjob_id%2F{job_id}&project={project}')
# Summary format for displaying a job resource.
JOB_FORMAT = 'yaml(jobId,state,startTime.date(tz=LOCAL),endTime.date(tz=LOCAL))'
# Check every 10 seconds if the job is complete (if we didn't fetch any logs
# the last time)
_CONTINUE_INTERVAL = 10
# Documentation links referenced from the --data-format flag help text.
_TEXT_FILE_URL = ('https://www.tensorflow.org/guide/datasets'
                  '#consuming_text_data')
_JSON_FILE_URL = ('https://cloud.google.com/ai-platform/prediction/docs/'
                  'overview#batch_prediction_input_data')
_TF_RECORD_URL = ('https://www.tensorflow.org/guide/datasets'
                  '#consuming_tfrecord_data')
# Maps the prediction DataFormat enum values to CLI choices for --data-format.
_PREDICTION_DATA_FORMAT_MAPPER = arg_utils.ChoiceEnumMapper(
    '--data-format',
    jobs.GetMessagesModule(
    ).GoogleCloudMlV1PredictionInput.DataFormatValueValuesEnum,
    custom_mappings={
        'TEXT': ('text',
                 ('Text and JSON files; for text files, see {}, for JSON files,'
                  ' see {}'.format(_TEXT_FILE_URL, _JSON_FILE_URL))),
        'TF_RECORD': ('tf-record',
                      'TFRecord files; see {}'.format(_TF_RECORD_URL)),
        'TF_RECORD_GZIP': ('tf-record-gzip',
                           'GZIP-compressed TFRecord files.')
    },
    help_str='Data format of the input files.',
    required=True)
# Maps AcceleratorConfig type enum values to CLI choices for
# --accelerator-type.
_ACCELERATOR_MAP = arg_utils.ChoiceEnumMapper(
    '--accelerator-type',
    jobs.GetMessagesModule(
    ).GoogleCloudMlV1AcceleratorConfig.TypeValueValuesEnum,
    custom_mappings={
        'NVIDIA_TESLA_K80': ('nvidia-tesla-k80', 'NVIDIA Tesla K80 GPU'),
        'NVIDIA_TESLA_P100': ('nvidia-tesla-p100', 'NVIDIA Tesla P100 GPU.')
    },
    help_str='The available types of accelerators.',
    required=True)
# (CLI choice name, help text) per scale tier; feeds
# _TRAINING_SCALE_TIER_MAPPER below.
_SCALE_TIER_CHOICES = {
    'BASIC': ('basic', ('Single worker instance. This tier is suitable for '
                        'learning how to use AI Platform, and for '
                        'experimenting with new models using small datasets.')),
    'STANDARD_1': ('standard-1', 'Many workers and a few parameter servers.'),
    'PREMIUM_1': ('premium-1',
                  'Large number of workers with many parameter servers.'),
    'BASIC_GPU': ('basic-gpu', 'Single worker instance with a GPU.'),
    'BASIC_TPU': ('basic-tpu', 'Single worker instance with a Cloud TPU.'),
    'CUSTOM': ('custom', """\
CUSTOM tier is not a set tier, but rather enables you to use your own
cluster specification. When you use this tier, set values to configure your
processing cluster according to these guidelines (using the `--config` flag):
* You _must_ set `TrainingInput.masterType` to specify the type of machine to
use for your master node. This is the only required setting.
* You _may_ set `TrainingInput.workerCount` to specify the number of workers to
use. If you specify one or more workers, you _must_ also set
`TrainingInput.workerType` to specify the type of machine to use for your
worker nodes.
* You _may_ set `TrainingInput.parameterServerCount` to specify the number of
parameter servers to use. If you specify one or more parameter servers, you
_must_ also set `TrainingInput.parameterServerType` to specify the type of
machine to use for your parameter servers. Note that all of your workers must
use the same machine type, which can be different from your parameter server
type and master type. Your parameter servers must likewise use the same
machine type, which can be different from your worker type and master type.\
""")
}
# Maps the TrainingInput ScaleTier enum values to CLI choices for --scale-tier.
_TRAINING_SCALE_TIER_MAPPER = arg_utils.ChoiceEnumMapper(
    '--scale-tier',
    jobs.GetMessagesModule()
    .GoogleCloudMlV1TrainingInput.ScaleTierValueValuesEnum,
    custom_mappings=_SCALE_TIER_CHOICES,
    help_str=('Specify the machine types, the number of replicas for workers, '
              'and parameter servers.'),
    default=None)
class TrainingCustomInputServerConfig(object):
  """Data class for passing custom server config for training job input.

  Holds the machine type, accelerator, and container image configuration for
  the master, worker, and parameter-server pools of a training job, plus the
  runtime version and scale tier.

  NOTE(review): the worker accelerator fields are named `work_accelerator_*`
  (not `worker_accelerator_*`); the names are kept for interface
  compatibility with existing callers.
  """

  def __init__(self,
               runtime_version,
               scale_tier,
               master_machine_type=None,
               master_image_uri=None,
               master_accelerator_type=None,
               master_accelerator_count=None,
               parameter_machine_type=None,
               parameter_machine_count=None,
               parameter_image_uri=None,
               parameter_accelerator_type=None,
               parameter_accelerator_count=None,
               tpu_tf_version=None,
               worker_machine_type=None,
               worker_machine_count=None,
               worker_image_uri=None,
               work_accelerator_type=None,
               work_accelerator_count=None,
               use_chief_in_tf_config=None):
    self.master_image_uri = master_image_uri
    self.master_machine_type = master_machine_type
    self.master_accelerator_type = master_accelerator_type
    self.master_accelerator_count = master_accelerator_count
    self.parameter_machine_type = parameter_machine_type
    self.parameter_machine_count = parameter_machine_count
    self.parameter_image_uri = parameter_image_uri
    self.parameter_accelerator_type = parameter_accelerator_type
    self.parameter_accelerator_count = parameter_accelerator_count
    self.tpu_tf_version = tpu_tf_version
    self.worker_machine_type = worker_machine_type
    self.worker_machine_count = worker_machine_count
    self.worker_image_uri = worker_image_uri
    self.work_accelerator_type = work_accelerator_type
    self.work_accelerator_count = work_accelerator_count
    self.runtime_version = runtime_version
    self.scale_tier = scale_tier
    self.use_chief_in_tf_config = use_chief_in_tf_config

  def ValidateConfig(self):
    """Validate that custom config parameters are set correctly.

    Returns:
      True if the configuration is valid.

    Raises:
      flags.ArgumentError: if both --master-image-uri and --runtime-version
        are set, or if the scale tier is CUSTOM but --master-machine-type is
        missing.
    """
    if self.master_image_uri and self.runtime_version:
      raise flags.ArgumentError('Only one of --master-image-uri,'
                                ' --runtime-version can be set.')
    if self.scale_tier and self.scale_tier.name == 'CUSTOM':
      if not self.master_machine_type:
        raise flags.ArgumentError('--master-machine-type is required if '
                                  'scale-tier is set to `CUSTOM`.')
    return True

  def GetFieldMap(self):
    """Return a mapping of object fields to apitools message fields."""
    return {
        'masterConfig': {
            'imageUri': self.master_image_uri,
            'acceleratorConfig': {
                'count': self.master_accelerator_count,
                'type': self.master_accelerator_type
            }
        },
        'masterType': self.master_machine_type,
        'parameterServerConfig': {
            'imageUri': self.parameter_image_uri,
            'acceleratorConfig': {
                'count': self.parameter_accelerator_count,
                'type': self.parameter_accelerator_type
            }
        },
        'parameterServerCount': self.parameter_machine_count,
        'parameterServerType': self.parameter_machine_type,
        'workerConfig': {
            'imageUri': self.worker_image_uri,
            'acceleratorConfig': {
                'count': self.work_accelerator_count,
                'type': self.work_accelerator_type
            },
            'tpuTfVersion': self.tpu_tf_version
        },
        'workerCount': self.worker_machine_count,
        'workerType': self.worker_machine_type,
        'useChiefInTfConfig': self.use_chief_in_tf_config,
    }

  @classmethod
  def FromArgs(cls, args, support_tpu_tf_version=False):
    """Build TrainingCustomInputServerConfig from argparse.Namespace.

    Args:
      args: argparse.Namespace, the parsed command-line arguments.
      support_tpu_tf_version: bool, whether to read args.tpu_tf_version.

    Returns:
      TrainingCustomInputServerConfig built from the flag values.
    """
    tier = args.scale_tier
    if not tier:
      if args.config:
        # Fall back to the scale tier declared in the --config YAML file.
        data = yaml.load_path(args.config)
        tier = data.get('trainingInput', {}).get('scaleTier', None)
    parsed_tier = ScaleTierFlagMap().GetEnumForChoice(tier)
    return cls(
        scale_tier=parsed_tier,
        runtime_version=args.runtime_version,
        master_machine_type=args.master_machine_type,
        master_image_uri=args.master_image_uri,
        master_accelerator_type=(args.master_accelerator.get('type')
                                 if args.master_accelerator else None),
        master_accelerator_count=(args.master_accelerator.get('count')
                                  if args.master_accelerator else None),
        parameter_machine_type=args.parameter_server_machine_type,
        parameter_machine_count=args.parameter_server_count,
        parameter_image_uri=args.parameter_server_image_uri,
        parameter_accelerator_type=args.parameter_server_accelerator.get('type')
        if args.parameter_server_accelerator else None,
        parameter_accelerator_count=args.parameter_server_accelerator.get(
            'count') if args.parameter_server_accelerator else None,
        tpu_tf_version=args.tpu_tf_version if support_tpu_tf_version else None,
        worker_machine_type=args.worker_machine_type,
        worker_machine_count=args.worker_count,
        worker_image_uri=args.worker_image_uri,
        work_accelerator_type=(args.worker_accelerator.get('type')
                               if args.worker_accelerator else None),
        work_accelerator_count=(args.worker_accelerator.get('count')
                                if args.worker_accelerator else None),
        use_chief_in_tf_config=args.use_chief_in_tf_config)
def DataFormatFlagMap():
  """Returns the ChoiceEnumMapper backing the --data-format flag."""
  return _PREDICTION_DATA_FORMAT_MAPPER
def AcceleratorFlagMap():
  """Returns the ChoiceEnumMapper backing the --accelerator-type flag."""
  return _ACCELERATOR_MAP
def ScaleTierFlagMap():
  """Return the ChoiceEnumMapper backing the --scale-tier flag."""
  return _TRAINING_SCALE_TIER_MAPPER
def _ParseJob(job):
  """Parse a job ID string into an ml.projects.jobs resource reference."""
  project_param = {'projectsId': properties.VALUES.core.project.GetOrFail}
  return resources.REGISTRY.Parse(
      job, params=project_param, collection='ml.projects.jobs')
def Cancel(jobs_client, job):
  """Cancel the given job through the jobs client."""
  return jobs_client.Cancel(_ParseJob(job))
def PrintDescribeFollowUp(job_id):
  """Print Cloud Console and logs URLs for the described job."""
  project = properties.VALUES.core.project.Get()
  console_url = _CONSOLE_URL.format(job_id=job_id, project=project)
  logs_url = _LOGS_URL.format(job_id=job_id, project=project)
  log.status.Print('\nView job in the Cloud Console at:\n' + console_url)
  log.status.Print('\nView logs at:\n' + logs_url)
def Describe(jobs_client, job):
  """Fetch the job resource for the given job ID."""
  return jobs_client.Get(_ParseJob(job))
def List(jobs_client):
  """List jobs under the currently configured project."""
  project = properties.VALUES.core.project.Get(required=True)
  project_ref = resources.REGISTRY.Parse(project, collection='ml.projects')
  return jobs_client.List(project_ref)
def StreamLogs(job, task_name, polling_interval,
               allow_multiline_logs):
  """Yield log entries for a job, optionally splitting multiline messages."""
  fetcher = stream.LogFetcher(
      filters=log_utils.LogFilters(job, task_name),
      polling_interval=polling_interval,
      continue_interval=_CONTINUE_INTERVAL,
      continue_func=log_utils.MakeContinueFunction(job))
  return log_utils.SplitMultiline(
      fetcher.YieldLogs(), allow_multiline=allow_multiline_logs)
# Follow-up guidance printed after a job is submitted and when log streaming
# is interrupted; expects a `job_id` format argument.
_FOLLOW_UP_MESSAGE = """\
Your job is still active. You may view the status of your job with the command
$ gcloud ai-platform jobs describe {job_id}
or continue streaming the logs with the command
$ gcloud ai-platform jobs stream-logs {job_id}\
"""
def PrintSubmitFollowUp(job_id, print_follow_up_message=True):
  """Confirm submission and optionally print follow-up instructions."""
  log.status.Print('Job [{}] submitted successfully.'.format(job_id))
  if not print_follow_up_message:
    return
  log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job_id))
def GetStreamLogs(asyncronous, stream_logs):
  """Decide, from the command-line flags, whether to stream logs.

  The two flags are mutually exclusive and both default to False.

  Args:
    asyncronous: bool, the value of the --async flag.
    stream_logs: bool, the value of the --stream-logs flag.

  Returns:
    bool, whether to stream the logs

  Raises:
    ValueError: if both asyncronous and stream_logs are True.
  """
  if not asyncronous:
    return stream_logs
  if stream_logs:
    # Doesn't have to be a nice error; they're mutually exclusive so we should
    # never get here.
    raise ValueError('--async and --stream-logs cannot both be set.')
  log.warning('The --async flag is deprecated, as the default behavior is to '
              'submit the job asynchronously; it can be omitted. '
              'For synchronous behavior, please pass --stream-logs.\n')
  return stream_logs
def ParseCreateLabels(jobs_client, args):
  """Parse the --labels creation flags into the client's LabelsValue."""
  labels_value = jobs_client.job_class.LabelsValue
  return labels_util.ParseCreateArgs(args, labels_value)
def SubmitTraining(jobs_client,
                   job,
                   job_dir=None,
                   staging_bucket=None,
                   packages=None,
                   package_path=None,
                   scale_tier=None,
                   config=None,
                   module_name=None,
                   runtime_version=None,
                   network=None,
                   service_account=None,
                   python_version=None,
                   stream_logs=None,
                   user_args=None,
                   labels=None,
                   kms_key=None,
                   custom_train_server_config=None,
                   enable_web_access=None):
  """Submit a training job.

  Stages local Python packages, builds the training-job message, creates the
  job, and either returns immediately or streams the job's logs until the user
  interrupts or polling fails.

  Args:
    jobs_client: the jobs API client used to build and create the job.
    job: str, the job ID to submit; rebound to the created job resource below.
    job_dir: resource ref or None, GCS directory for job output/staging.
    staging_bucket: GCS bucket used to stage local packages.
    packages: local Python packages to upload.
    package_path: path to a local package directory to build and upload.
    scale_tier: scale tier choice (converted to the client's enum below).
    config: path to a training config YAML file.
    module_name: str, the Python module to run.
    runtime_version: str, AI Platform runtime version.
    network: network configuration for the job.
    service_account: service account to run the job as.
    python_version: str, Python version for the job.
    stream_logs: bool, whether to stream the job's logs after submission.
    user_args: additional args passed through to the training module.
    labels: labels message for the job.
    kms_key: customer-managed encryption key.
    custom_train_server_config: custom machine configuration.
    enable_web_access: bool, whether to enable interactive shell access.

  Returns:
    The created (and, if logs were streamed, refreshed) job resource.

  Raises:
    flags.ArgumentError: if staging is required but no staging location was
      given, or if neither --package-path nor --packages was specified.
  """
  region = properties.VALUES.compute.region.Get(required=True)
  staging_location = jobs_prep.GetStagingLocation(
      staging_bucket=staging_bucket, job_id=job,
      job_dir=job_dir)
  try:
    uris = jobs_prep.UploadPythonPackages(
        packages=packages,
        package_path=package_path,
        staging_location=staging_location)
  except jobs_prep.NoStagingLocationError:
    raise flags.ArgumentError(
        'If local packages are provided, the `--staging-bucket` or '
        '`--job-dir` flag must be given.')
  log.debug('Using {0} as trainer uris'.format(uris))
  scale_tier_enum = jobs_client.training_input_class.ScaleTierValueValuesEnum
  scale_tier = scale_tier_enum(scale_tier) if scale_tier else None
  try:
    job = jobs_client.BuildTrainingJob(
        path=config,
        module_name=module_name,
        job_name=job,
        trainer_uri=uris,
        region=region,
        job_dir=job_dir.ToUrl() if job_dir else None,
        scale_tier=scale_tier,
        user_args=user_args,
        runtime_version=runtime_version,
        network=network,
        service_account=service_account,
        python_version=python_version,
        labels=labels,
        kms_key=kms_key,
        custom_train_server_config=custom_train_server_config,
        enable_web_access=enable_web_access)
  except jobs_prep.NoStagingLocationError:
    raise flags.ArgumentError(
        'If `--package-path` is not specified, at least one Python package '
        'must be specified via `--packages`.')
  project_ref = resources.REGISTRY.Parse(
      properties.VALUES.core.project.Get(required=True),
      collection='ml.projects')
  job = jobs_client.Create(project_ref, job)
  if not stream_logs:
    PrintSubmitFollowUp(job.jobId, print_follow_up_message=True)
    return job
  else:
    PrintSubmitFollowUp(job.jobId, print_follow_up_message=False)
    log_fetcher = stream.LogFetcher(
        filters=log_utils.LogFilters(job.jobId),
        polling_interval=properties.VALUES.ml_engine.polling_interval.GetInt(),
        continue_interval=_CONTINUE_INTERVAL,
        continue_func=log_utils.MakeContinueFunction(job.jobId))
    printer = resource_printer.Printer(log_utils.LOG_FORMAT,
                                       out=log.err)
    # On interrupt or polling failure, print how to resume following the job
    # rather than failing the command.
    with execution_utils.RaisesKeyboardInterrupt():
      try:
        printer.Print(log_utils.SplitMultiline(log_fetcher.YieldLogs()))
      except KeyboardInterrupt:
        log.status.Print('Received keyboard interrupt.\n')
        log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId,
                                                   project=project_ref.Name()))
      except exceptions.HttpError as err:
        log.status.Print('Polling logs failed:\n{}\n'.format(six.text_type(err)))
        log.info('Failure details:', exc_info=True)
        log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId,
                                                   project=project_ref.Name()))
    # Re-fetch the job so the returned resource reflects its latest state.
    job_ref = resources.REGISTRY.Parse(
        job.jobId,
        params={'projectsId': properties.VALUES.core.project.GetOrFail},
        collection='ml.projects.jobs')
    job = jobs_client.Get(job_ref)
    return job
def _ValidateSubmitPredictionArgs(model_dir, version):
if model_dir and version:
raise flags.ArgumentError('`--version` cannot be set with `--model-dir`')
def SubmitPrediction(jobs_client, job,
                     model_dir=None, model=None, version=None,
                     input_paths=None, data_format=None, output_path=None,
                     region=None, runtime_version=None, max_worker_count=None,
                     batch_size=None, signature_name=None, labels=None,
                     accelerator_count=None, accelerator_type=None):
  """Submit a batch prediction job.

  Args:
    jobs_client: the jobs API client used to build and create the job.
    job: str, the job ID to submit.
    model_dir: str or None, GCS path of a model directory (mutually exclusive
      with version).
    model: str or None, name of the deployed model to use.
    version: str or None, version of the model (mutually exclusive with
      model_dir).
    input_paths: input file paths for prediction.
    data_format: format of the input data.
    output_path: GCS path for the prediction output.
    region: region to run the prediction job in.
    runtime_version: str, AI Platform runtime version.
    max_worker_count: maximum number of workers.
    batch_size: number of records per batch.
    signature_name: TensorFlow signature to use.
    labels: labels message for the job.
    accelerator_count: number of accelerators per worker.
    accelerator_type: accelerator type choice, mapped to its enum below.

  Returns:
    The created job resource.

  Raises:
    flags.ArgumentError: if both model_dir and version are set.
  """
  _ValidateSubmitPredictionArgs(model_dir, version)
  project_ref = resources.REGISTRY.Parse(
      properties.VALUES.core.project.Get(required=True),
      collection='ml.projects')
  job = jobs_client.BuildBatchPredictionJob(
      job_name=job,
      model_dir=model_dir,
      model_name=model,
      version_name=version,
      input_paths=input_paths,
      data_format=data_format,
      output_path=output_path,
      region=region,
      runtime_version=runtime_version,
      max_worker_count=max_worker_count,
      batch_size=batch_size,
      signature_name=signature_name,
      labels=labels,
      accelerator_count=accelerator_count,
      accelerator_type=_ACCELERATOR_MAP.GetEnumForChoice(accelerator_type)
  )
  # Print follow-up before Create so the user sees guidance even if streaming
  # output interleaves with the API call's result rendering.
  PrintSubmitFollowUp(job.jobId, print_follow_up_message=True)
  return jobs_client.Create(project_ref, job)
def GetSummaryFormat(job):
  """Get summary table format for an ml job resource.

  Args:
    job: job resource to build summary output for.

  Returns:
    dynamic format string for resource output.
  """
  if not job:
    return 'yaml'  # Fallback to yaml on empty resource
  training_input = getattr(job, 'trainingInput', False)
  if not training_input:
    return flags.GetPredictJobSummary()
  if getattr(training_input, 'hyperparameters', False):
    return flags.GetHPTrainingJobSummary()
  return flags.GetStandardTrainingJobSummary()
def ParseUpdateLabels(client, job_ref, args):
  """Lazily resolve label-update flags against the job's current labels."""
  def _FetchCurrentLabels():
    return client.Get(job_ref).labels
  return labels_util.ProcessUpdateArgsLazy(
      args, client.job_class.LabelsValue, _FetchCurrentLabels)
def Update(jobs_client, args):
  """Update a job."""
  job_ref = _ParseJob(args.job)
  labels_update = ParseUpdateLabels(jobs_client, job_ref, args)
  try:
    return jobs_client.Patch(job_ref, labels_update)
  except jobs.NoFieldsSpecifiedError:
    label_flags = ('update_labels', 'clear_labels', 'remove_labels')
    if not any(args.IsSpecified(flag) for flag in label_flags):
      # The user asked for a non-label update we don't support; surface it.
      raise
    log.status.Print('No update to perform.')
    return None

View File

@@ -0,0 +1,184 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for running predictions locally.
This module will always be run within a subprocess, and therefore normal
conventions of Cloud SDK do not apply here.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import json
import sys
def eprint(*args, **kwargs):
  """Print to stderr."""
  # Plain print() rather than core.log: this module runs as a standalone
  # subprocess launched by gcloud, outside normal Cloud SDK conventions.
  print(*args, file=sys.stderr, **kwargs)
# User-facing hints for checking installed library versions.  The suggested
# commands use print-function syntax so they work when copy-pasted into a
# Python 3 interpreter (the previous `print x` statement form is a syntax
# error on Python 3).
VERIFY_TENSORFLOW_VERSION = ('Please verify the installed tensorflow version '
                             'with: "python -c \'import tensorflow; '
                             'print(tensorflow.__version__)\'".')
VERIFY_SCIKIT_LEARN_VERSION = ('Please verify the installed sklearn version '
                               'with: "python -c \'import sklearn; '
                               'print(sklearn.__version__)\'".')
VERIFY_XGBOOST_VERSION = ('Please verify the installed xgboost version '
                          'with: "python -c \'import xgboost; '
                          'print(xgboost.__version__)\'".')
def _verify_tensorflow(version):
  """Check whether TensorFlow is installed at an appropriate version.

  Args:
    version: str, the minimum acceptable version string, e.g. '1.0.0'.

  Returns:
    bool, True iff TensorFlow imports and reports a version >= `version`.
  """
  # Check tensorflow with a recent version is installed.
  try:
    # pylint: disable=g-import-not-at-top
    import tensorflow.compat.v1 as tf
    # pylint: enable=g-import-not-at-top
  except ImportError:
    eprint('Cannot import Tensorflow. Please verify '
           '"python -c \'import tensorflow\'" works.')
    return False
  try:
    # NOTE(review): this is a lexicographic string comparison, not a semantic
    # version compare; it can misorder versions like '1.9' vs '1.10'.
    if tf.__version__ < version:
      eprint('Tensorflow version must be at least {} .'.format(version),
             VERIFY_TENSORFLOW_VERSION)
      return False
  except (NameError, AttributeError) as e:
    eprint('Error while getting the installed TensorFlow version: ', e,
           '\n', VERIFY_TENSORFLOW_VERSION)
    return False
  return True
def _verify_scikit_learn(version):
  """Check whether scikit-learn is installed at an appropriate version.

  Args:
    version: str, the minimum acceptable version string, e.g. '0.18.1'.

  Returns:
    bool, True iff scipy and sklearn import and sklearn reports a
    version >= `version`.
  """
  # Check scikit-learn with a recent version is installed.
  try:
    # pylint: disable=g-import-not-at-top
    import scipy  # pylint: disable=unused-variable
    # pylint: enable=g-import-not-at-top
  except ImportError:
    eprint('Cannot import scipy, which is needed for scikit-learn. Please '
           'verify "python -c \'import scipy\'" works.')
    return False
  try:
    # pylint: disable=g-import-not-at-top
    import sklearn
    # pylint: enable=g-import-not-at-top
  except ImportError:
    eprint('Cannot import sklearn. Please verify '
           '"python -c \'import sklearn\'" works.')
    return False
  try:
    # NOTE(review): lexicographic string comparison, not a semantic version
    # compare; adequate for the versions checked here but not in general.
    if sklearn.__version__ < version:
      eprint('Scikit-learn version must be at least {} .'.format(version),
             VERIFY_SCIKIT_LEARN_VERSION)
      return False
  except (NameError, AttributeError) as e:
    eprint('Error while getting the installed scikit-learn version: ', e, '\n',
           VERIFY_SCIKIT_LEARN_VERSION)
    return False
  return True
def _verify_xgboost(version):
  """Check whether xgboost is installed at an appropriate version.

  Args:
    version: str, the minimum acceptable version string, e.g. '0.6a2'.

  Returns:
    bool, True iff xgboost imports and reports a version >= `version`.
  """
  # Check xgboost with a recent version is installed.
  try:
    # pylint: disable=g-import-not-at-top
    import xgboost
    # pylint: enable=g-import-not-at-top
  except ImportError:
    eprint('Cannot import xgboost. Please verify '
           '"python -c \'import xgboost\'" works.')
    return False
  try:
    # NOTE(review): lexicographic string comparison, not a semantic version
    # compare; adequate for the versions checked here but not in general.
    if xgboost.__version__ < version:
      eprint('Xgboost version must be at least {} .'.format(version),
             VERIFY_XGBOOST_VERSION)
      return False
  except (NameError, AttributeError) as e:
    eprint('Error while getting the installed xgboost version: ', e, '\n',
           VERIFY_XGBOOST_VERSION)
    return False
  return True
def _verify_ml_libs(framework):
  """Verifies the appropriate ML libs are installed per framework."""
  # Map each known framework to its (lazy) verification call; unknown
  # frameworks are accepted without any check, matching the if/elif chain.
  verifiers = {
      'tensorflow': lambda: _verify_tensorflow('1.0.0'),
      'scikit_learn': lambda: _verify_scikit_learn('0.18.1'),
      'xgboost': lambda: _verify_xgboost('0.6a2'),
  }
  verifier = verifiers.get(framework)
  if verifier is not None and not verifier():
    sys.exit(-1)
def main():
  """Entry point: read JSON instances from stdin and print predictions.

  Expects one JSON-encoded instance per line on stdin; writes the result of
  prediction_lib.local_predict to stdout as a single JSON document.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument('--model-dir', required=True, help='Path of the model.')
  parser.add_argument(
      '--framework',
      required=False,
      default=None,
      help=('The ML framework used to train this version of the model. '
            'If not specified, the framework will be identified based on'
            ' the model file name stored in the specified model-dir'))
  parser.add_argument('--signature-name', required=False,
                      help='Tensorflow signature to select input/output map.')
  args, _ = parser.parse_known_args()
  if args.framework is None:
    # No --framework given: sniff it from the model directory's contents.
    from cloud.ml.prediction import prediction_utils  # pylint: disable=g-import-not-at-top
    framework = prediction_utils.detect_framework(args.model_dir)
  else:
    framework = args.framework
  if framework:
    _verify_ml_libs(framework)
  # We want to do this *after* we verify ml libs so the user gets a nicer
  # error message.
  # pylint: disable=g-import-not-at-top
  from cloud.ml.prediction import prediction_lib
  # pylint: enable=g-import-not-at-top
  instances = []
  # One JSON instance per stdin line.
  for line in sys.stdin:
    instance = json.loads(line.rstrip('\n'))
    instances.append(instance)
  predictions = prediction_lib.local_predict(
      model_dir=args.model_dir,
      instances=instances,
      framework=framework,
      signature_name=args.signature_name)
  print(json.dumps(predictions))
if __name__ == '__main__':
  main()

View File

@@ -0,0 +1,198 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for running training jobs locally."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import atexit
import json
import os
import subprocess
from googlecloudsdk.core import execution_utils
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import encoding
from googlecloudsdk.core.util import files
from six.moves import range
def GetPrimaryNodeName():
  """Get the primary node name.

  Shells out to the configured Python interpreter to probe the installed
  TensorFlow version, since the gcloud process itself may not have (or want
  to import) TensorFlow.

  Returns:
    str, the name of the primary node. If running in tensorflow 1.x,
    return 'master'. If running in tensorflow 2.x, return 'chief'.
    If tensorflow is not installed in local envrionment, it will return
    the default name 'chief'.

  Raises:
    ValueError: if there is no python executable on the user system thrown by
      execution_utils.GetPythonExecutable.
  """
  # Resolution order: explicit ml_engine/local_python property, then the
  # `python` found on PATH, then the SDK's own interpreter.
  exe_override = properties.VALUES.ml_engine.local_python.Get()
  python_executable = (
      exe_override or files.FindExecutableOnPath('python') or
      execution_utils.GetPythonExecutable())
  cmd = [python_executable,
         '-c',
         'import tensorflow as tf; print(tf.version.VERSION)']
  # Discard stderr (e.g. TF import warnings); only the printed version matters.
  with files.FileWriter(os.devnull) as f:
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=f)
    return_code = proc.wait()
  if return_code != 0:
    log.warning('''
Cannot import tensorflow under path {}. Using "chief" for cluster setting.
If this is not intended, Please check if tensorflow is installed. Please also
verify if the python path used is correct. If not, to change the python path:
use `gcloud config set ml_engine/local_python $python_path`
Eg: gcloud config set ml_engine/local_python /usr/bin/python3'''.format(
    python_executable))
    return 'chief'
  tf_version = proc.stdout.read()
  # On Python 3 the pipe yields bytes; decode before string checks.
  if 'decode' in dir(tf_version):
    tf_version = tf_version.decode('utf-8')
  if tf_version.startswith('1.'):
    return 'master'
  elif tf_version.startswith('2.'):
    return 'chief'
  log.warning(
      'Unexpected tensorflow version {}, using the default primary'
      ' node name, aka "chief" for cluster settings'.format(tf_version))
  return 'chief'
def MakeProcess(module_name,
                package_root,
                args=None,
                cluster=None,
                task_type=None,
                index=None,
                **extra_popen_args):
  """Make a Popen object that runs the module, with the correct env.

  If task_type is primary instead replaces the current process with the
  subprocess via execution_utils.Exec

  Args:
    module_name: str. Name of the module to run, e.g. trainer.task
    package_root: str. Absolute path to the package root for the module.
      used as CWD for the subprocess.
    args: [str]. Additional user args. Any relative paths will not work.
    cluster: dict. Cluster configuration dictionary. Suitable for passing to
      tf.train.ClusterSpec.
    task_type: str. Task type of this process. Only relevant if cluster is
      specified.
    index: int. Task index of this process.
    **extra_popen_args: extra args passed to Popen. Used for testing.

  Returns:
    a subprocess.Popen object corresponding to the subprocesses or an int
    corresponding to the return value of the subprocess
    (if task_type is primary)

  Raises:
    ValueError: if there is no python executable on the user system thrown by
      execution_utils.GetPythonExecutable.
  """
  if args is None:
    args = []
  # Resolution order: explicit ml_engine/local_python property, then the
  # `python` found on PATH, then the SDK's own interpreter.
  exe_override = properties.VALUES.ml_engine.local_python.Get()
  python_executable = (
      exe_override or files.FindExecutableOnPath('python') or
      execution_utils.GetPythonExecutable())
  cmd = [python_executable, '-m', module_name] + args
  config = {
      'job': {'job_name': module_name, 'args': args},
      'task': {'type': task_type, 'index': index} if cluster else {},
      'cluster': cluster or {},
      'environment': 'cloud'
  }
  log.info(('launching training process:\n'
            'command: {cmd}\n config: {config}').format(
                cmd=' '.join(cmd),
                config=json.dumps(config, indent=2, sort_keys=True)))
  env = os.environ.copy()
  # The TF_CONFIG environment variable is used to pass the TensorFlow
  # configuration options to the training module; the module-specific
  # arguments are passed as command line arguments.
  env['TF_CONFIG'] = json.dumps(config)
  if task_type == GetPrimaryNodeName():
    # Run the primary in the foreground so its exit status becomes ours.
    return execution_utils.Exec(
        cmd, env=env, no_exit=True, cwd=package_root, **extra_popen_args)
  else:
    env = encoding.EncodeEnv(env)
    task = subprocess.Popen(
        cmd,
        env=env,
        cwd=package_root,
        **extra_popen_args
    )
    # Ensure background tasks are cleaned up when gcloud exits.
    atexit.register(execution_utils.KillSubprocess, task)
    return task
def RunDistributed(module_name,
                   package_root,
                   num_ps,
                   num_workers,
                   num_evaluators,
                   start_port,
                   user_args=None):
  """Create a cluster configuration and start processes for the cluster.

  Args:
    module_name: str. Python module to use as the task.
    package_root: str. Absolute path to the package root of the module.
    num_ps: int. Number of parameter servers
    num_workers: int. Number of workers.
    num_evaluators: int. Number of evaluators.
    start_port: int. First port for the contiguous block of ports used
      by the cluster.
    user_args: [str]. Additional user args for the task. Any relative paths
      will not work.

  Returns:
    int. the retval of primary subprocess
  """
  # GetPrimaryNodeName() shells out to a Python subprocess to probe the
  # installed TensorFlow version, so compute it once here instead of calling
  # it repeatedly (once per loop iteration plus the cluster dict and the
  # final MakeProcess call, as before).
  primary_name = GetPrimaryNodeName()
  # One port for the primary plus one per ps/worker; evaluators are not part
  # of the cluster spec and need no address.
  ports = list(range(start_port, start_port + num_ps + num_workers + 1))
  cluster = {
      primary_name: ['localhost:{port}'.format(port=ports[0])],
      'ps': ['localhost:{port}'.format(port=p)
             for p in ports[1:num_ps + 1]],
      'worker': ['localhost:{port}'.format(port=p)
                 for p in ports[num_ps + 1:]]
  }
  # Launch the non-primary cluster tasks in the background first.
  for task_type, addresses in cluster.items():
    if task_type != primary_name:
      for i in range(len(addresses)):
        MakeProcess(module_name,
                    package_root,
                    args=user_args,
                    task_type=task_type,
                    index=i,
                    cluster=cluster)
  for i in range(num_evaluators):
    MakeProcess(module_name,
                package_root,
                args=user_args,
                task_type='evaluator',
                index=i,
                cluster=cluster)
  # The primary runs in the foreground; its return value is ours.
  return MakeProcess(module_name,
                     package_root,
                     args=user_args,
                     task_type=primary_name,
                     index=0,
                     cluster=cluster)

View File

@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for local ml-engine operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import json
import os
import subprocess
from googlecloudsdk.command_lib.ml_engine import local_predict
from googlecloudsdk.command_lib.ml_engine import predict_utilities
from googlecloudsdk.core import config
from googlecloudsdk.core import exceptions as core_exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import encoding
from googlecloudsdk.core.util import files
class InvalidInstancesFileError(core_exceptions.Error):
  """Indicates a problem with the supplied prediction instances input."""
  pass
class LocalPredictRuntimeError(core_exceptions.Error):
  """Indicates that the local_predict subprocess failed while running."""
  pass
class LocalPredictEnvironmentError(core_exceptions.Error):
  """Indicates the local environment cannot support local prediction.

  Raised when no installed Cloud SDK root or no Python executable on PATH
  can be found.
  """
  pass
class InvalidReturnValueError(core_exceptions.Error):
  """Indicates that the output of local_predict is not valid JSON."""
  pass
def RunPredict(model_dir, json_request=None, json_instances=None,
               text_instances=None, framework='tensorflow',
               signature_name=None):
  """Run ML Engine local prediction in a subprocess.

  Args:
    model_dir: str, path of the model directory.
    json_request: str or None, path to a JSON request file.
    json_instances: str or None, path to a file of JSON instances.
    text_instances: str or None, path to a file of text instances.
    framework: str, the ML framework of the model (default 'tensorflow').
    signature_name: str or None, TensorFlow signature to select.

  Returns:
    The parsed JSON prediction results emitted by local_predict.

  Raises:
    LocalPredictEnvironmentError: if no installed Cloud SDK root or no Python
      executable can be located.
    LocalPredictRuntimeError: if the local_predict subprocess exits non-zero.
    InvalidReturnValueError: if the subprocess output is not valid JSON.
  """
  instances = predict_utilities.ReadInstancesFromArgs(json_request,
                                                      json_instances,
                                                      text_instances)
  sdk_root = config.Paths().sdk_root
  if not sdk_root:
    raise LocalPredictEnvironmentError(
        'You must be running an installed Cloud SDK to perform local '
        'prediction.')
  # Inheriting the environment preserves important variables in the child
  # process. In particular, LD_LIBRARY_PATH under linux and PATH under windows
  # could be used to point to non-standard install locations of CUDA and CUDNN.
  # If not inherited, the child process could fail to initialize Tensorflow.
  env = os.environ.copy()
  encoding.SetEncodedValue(env, 'CLOUDSDK_ROOT', sdk_root)
  # We want to use whatever the user's Python was, before the Cloud SDK started
  # changing the PATH. That's where Tensorflow is installed.
  python_executables = files.SearchForExecutableOnPath('python')
  # Need to ensure that ml_sdk is in PYTHONPATH for the import in
  # local_predict to succeed.
  orig_py_path = encoding.GetEncodedValue(env, 'PYTHONPATH') or ''
  if orig_py_path:
    # Use the platform's path separator; a hard-coded ':' corrupts
    # PYTHONPATH on Windows (which uses ';').
    orig_py_path = os.pathsep + orig_py_path
  encoding.SetEncodedValue(
      env, 'PYTHONPATH',
      os.path.join(sdk_root, 'lib', 'third_party', 'ml_sdk') + orig_py_path)
  if not python_executables:
    # This doesn't have to be actionable because things are probably beyond
    # help at this point.
    raise LocalPredictEnvironmentError(
        'Something has gone really wrong; we can\'t find a valid Python '
        'executable on your PATH.')
  # Use python found on PATH or local_python override if set
  python_executable = (properties.VALUES.ml_engine.local_python.Get() or
                       python_executables[0])
  predict_args = ['--model-dir', model_dir, '--framework', framework]
  if signature_name:
    predict_args += ['--signature-name', signature_name]
  # Start local prediction in a subprocess.
  args = [encoding.Encode(a) for a in
          ([python_executable, local_predict.__file__] + predict_args)]
  proc = subprocess.Popen(
      args,
      stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
      env=env)
  # Pass the instances to the process that actually runs local prediction.
  for instance in instances:
    proc.stdin.write((json.dumps(instance) + '\n').encode('utf-8'))
  proc.stdin.flush()
  # Get the results for the local prediction.
  output, err = proc.communicate()
  if proc.returncode != 0:
    raise LocalPredictRuntimeError(err)
  if err:
    log.warning(err)
  # Decode once; on Python 3 `output` is bytes, and the previous code's
  # `str + output` in the error message raised TypeError instead of the
  # intended InvalidReturnValueError.
  decoded_output = encoding.Decode(output)
  try:
    return json.loads(decoded_output)
  except ValueError:
    raise InvalidReturnValueError('The output for prediction is not '
                                  'in JSON format: ' + decoded_output)

View File

@@ -0,0 +1,178 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for the ml-engine client to use command_lib.logs.stream."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import copy
from apitools.base.py import encoding
from googlecloudsdk.api_lib.ml_engine import jobs
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
import six
# Resource-printer format used when rendering streamed job log entries
# (consumed with the dicts produced by SplitMultiline/_EntryToDict).
LOG_FORMAT = ('value('
              'severity,'
              'timestamp.date("%Y-%m-%d %H:%M:%S %z",tz="LOCAL"), '
              'task_name,'
              'trial_id,'
              'message'
              ')'
             )
# TODO(b/36057459): Remove ml_job after transition from ml_job to cloudml_job is
# done. See b/34459608.
def LogFilters(job_id, task_name=None):
  """Returns filters for log fetcher to use.

  Args:
    job_id: String id of job.
    task_name: String name of task.

  Returns:
    A list of filters to be passed to the logging API.
  """
  job_filters = [
      '(resource.type="ml_job" OR resource.type="cloudml_job")',
      'resource.labels.job_id="{0}"'.format(job_id),
  ]
  if not task_name:
    return job_filters
  task_filter = (
      '(resource.labels.task_name="{0}" OR labels.task_name="{0}")'.format(
          task_name))
  return job_filters + [task_filter]
def MakeContinueFunction(job_id):
  """Returns a function to decide if log fetcher should continue polling.

  Args:
    job_id: String id of job.

  Returns:
    A one-argument function decides if log fetcher should continue.
  """
  # Client and job reference are created once here and captured by the
  # closure, so every poll reuses them.
  jobs_client = jobs.JobsClient()
  project_id = properties.VALUES.core.project.Get(required=True)
  job_ref = resources.REGISTRY.Create(
      'ml.projects.jobs', jobsId=job_id, projectsId=project_id)
  def ShouldContinue(periods_without_logs):
    """Returns whether to continue polling the logs.

    Returns False only once we've checked the job and it is finished; we only
    check whether the job is finished once we've gone >1 interval without
    getting any new logs.

    Args:
      periods_without_logs: integer number of empty polls.

    Returns:
      True if we haven't tried polling more than once or if job is not
      finished.
    """
    if periods_without_logs <= 1:
      return True
    # An unset endTime means the job is still running.
    return jobs_client.Get(job_ref).endTime is None
  return ShouldContinue
def SplitMultiline(log_generator, allow_multiline=False):
  """Splits the dict output of logs into multiple lines.

  Args:
    log_generator: iterator that returns a an ml log in dict format.
    allow_multiline: Tells us if logs with multiline messages are okay or not.

  Yields:
    Single-line ml log dictionaries.
  """
  for entry in log_generator:
    entry_dict = _EntryToDict(entry)
    if allow_multiline:
      yield entry_dict
      continue
    # An empty message still yields one (empty) line.
    lines = entry_dict['message'].splitlines() or ['']
    for line in lines:
      per_line_entry = copy.deepcopy(entry_dict)
      per_line_entry['message'] = line
      yield per_line_entry
def _EntryToDict(log_entry):
  """Converts a log entry to a dictionary.

  Args:
    log_entry: a Cloud Logging LogEntry message with one of jsonPayload,
      textPayload or protoPayload set.

  Returns:
    dict with 'severity', 'timestamp', 'task_name', 'message' keys, plus
    'trial_id' and/or 'json' when present in the entry.
  """
  output = {}
  output[
      'severity'] = log_entry.severity.name if log_entry.severity else 'DEFAULT'
  output['timestamp'] = log_entry.timestamp
  label_attributes = _GetLabelAttributes(log_entry)
  output['task_name'] = label_attributes['task_name']
  if 'trial_id' in label_attributes:
    output['trial_id'] = label_attributes['trial_id']
  output['message'] = ''
  if log_entry.jsonPayload is not None:
    json_data = _ToDict(log_entry.jsonPayload)
    # 'message' contains a free-text message that we want to pull out of the
    # JSON.
    if 'message' in json_data:
      if json_data['message']:
        output['message'] += json_data['message']
      del json_data['message']
    # Don't put 'levelname' in the JSON, since it duplicates the
    # information in log_entry.severity.name
    if 'levelname' in json_data:
      del json_data['levelname']
    output['json'] = json_data
  elif log_entry.textPayload is not None:
    output['message'] += six.text_type(log_entry.textPayload)
  elif log_entry.protoPayload is not None:
    output['json'] = encoding.MessageToDict(log_entry.protoPayload)
  return output
def _GetLabelAttributes(log_entry):
  """Reads the label attributes of the given log entry."""
  attributes = {'task_name': 'unknown_task'}
  labels = _ToDict(log_entry.labels)
  resource_labels = {} if not log_entry.resource else _ToDict(
      log_entry.resource.labels)
  # Resource labels win over entry labels; the prefixed key is the fallback.
  for source, key in ((resource_labels, 'task_name'),
                      (labels, 'task_name'),
                      (labels, 'ml.googleapis.com/task_name')):
    if source.get(key) is not None:
      attributes['task_name'] = source[key]
      break
  for key in ('trial_id', 'ml.googleapis.com/trial_id'):
    if labels.get(key) is not None:
      attributes['trial_id'] = labels[key]
      break
  return attributes
def _ToDict(message):
if not message:
return {}
if isinstance(message, dict):
return message
else:
return encoding.MessageToDict(message)

View File

@@ -0,0 +1,194 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for ml-engine models commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import models
from googlecloudsdk.command_lib.iam import iam_util
from googlecloudsdk.command_lib.ml_engine import region_util
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import exceptions as core_exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core.console import console_io
# Resource registry collection for AI Platform models.
MODELS_COLLECTION = 'ml.projects.models'
def ParseModel(model):
  """Parses a model ID into a model resource object."""
  project_param = {'projectsId': properties.VALUES.core.project.GetOrFail}
  return resources.REGISTRY.Parse(
      model, params=project_param, collection=MODELS_COLLECTION)
def ParseCreateLabels(models_client, args):
  """Parse --labels creation flags into a GoogleCloudMlV1Model.LabelsValue."""
  labels_value = models_client.messages.GoogleCloudMlV1Model.LabelsValue
  return labels_util.ParseCreateArgs(args, labels_value)
class RegionArgError(core_exceptions.Error):
  """Indicates that both the --region and --regions flags were passed."""
  pass
def GetModelRegion(args):
  """Extract the region from the command line args.

  Args:
    args: arguments from parser.

  Returns:
    region, model_regions
    region: str, regional endpoint or global endpoint.
    model_regions: list, region where the model will be deployed.

  Raises:
    RegionArgError: if both --region and --regions are specified.
  """
  if args.IsSpecified('region') and args.IsSpecified('regions'):
    raise RegionArgError('Only one of --region or --regions can be specified.')
  if args.IsSpecified('regions'):
    return 'global', args.regions
  if args.IsSpecified('region') and args.region != 'global':
    return args.region, [args.region]
  region = region_util.GetRegion(args)
  if region != 'global':
    return region, [region]
  # Fix user-facing message typos: 'will be deployed' (was 'will deployed')
  # and 'googleapis.com' (was misspelled 'googelapis.com').
  log.warning(
      'To specify a region where the model will be deployed on the global '
      'endpoint, please use `--regions` and do not specify `--region`. '
      'Using [us-central1] by default on https://ml.googleapis.com. '
      'Please note that your model will be inaccessible from '
      'https://us-central1-ml.googleapis.com\n'
      '\n'
      'Learn more about regional endpoints and see a list of available '
      'regions: https://cloud.google.com/ai-platform/prediction/docs/'
      'regional-endpoints')
  return 'global', ['us-central1']
def Create(models_client, model, regions, enable_logging=None,
           enable_console_logging=None, labels=None, description=None):
  """Creates a model via the given client, forwarding all options."""
  return models_client.Create(
      model,
      regions,
      enable_logging=enable_logging,
      enable_console_logging=enable_console_logging,
      labels=labels,
      description=description)
def Delete(models_client, operations_client, model):
  """Deletes a model after user confirmation, waiting for the operation."""
  console_io.PromptContinue(
      'This will delete model [{}]...'.format(model), cancel_on_no=True)
  operation = models_client.Delete(model)
  result = operations_client.WaitForOperation(
      operation, message='Deleting model [{}]'.format(model))
  return result.response
def List(models_client):
  """Lists models in the current project."""
  project = properties.VALUES.core.project.GetOrFail()
  project_ref = resources.REGISTRY.Parse(project, collection='ml.projects')
  return models_client.List(project_ref)
def ParseUpdateLabels(models_client, args):
  """Builds a labels diff for update; current labels are fetched lazily."""
  labels_cls = models_client.messages.GoogleCloudMlV1Model.LabelsValue
  return labels_util.ProcessUpdateArgsLazy(
      args, labels_cls, lambda: models_client.Get(args.model).labels)
def Update(models_client, operations_client, args):
  """Patches a model's labels/description and waits for the operation.

  Returns None (after printing a notice) when no label/description flag
  was given, i.e. there is nothing to update.
  """
  model_ref = ParseModel(args.model)
  labels_update = ParseUpdateLabels(models_client, args)
  try:
    op = models_client.Patch(model_ref, labels_update,
                             description=args.description)
  except models.NoFieldsSpecifiedError:
    update_flags = ('update_labels', 'clear_labels', 'remove_labels',
                    'description')
    if not any(args.IsSpecified(flag) for flag in update_flags):
      raise
    log.status.Print('No update to perform.')
    return None
  return operations_client.WaitForOperation(
      op, message='Updating model [{}]'.format(args.model)).response
def GetIamPolicy(models_client, model):
  """Fetches the IAM policy for the given model."""
  return models_client.GetIamPolicy(ParseModel(model))
def SetIamPolicy(models_client, model, policy_file):
  """Sets the model's IAM policy from a policy file and logs the action."""
  model_ref = ParseModel(model)
  policy, update_mask = iam_util.ParsePolicyFileWithUpdateMask(
      policy_file, models_client.messages.GoogleIamV1Policy)
  iam_util.LogSetIamPolicy(model_ref.Name(), 'model')
  return models_client.SetIamPolicy(model_ref, policy, update_mask)
def AddIamPolicyBinding(models_client, model, member, role):
  """Adds a member/role binding to the model's IAM policy."""
  model_ref = ParseModel(model)
  policy = models_client.GetIamPolicy(model_ref)
  binding_cls = models_client.messages.GoogleIamV1Binding
  iam_util.AddBindingToIamPolicy(binding_cls, policy, member, role)
  return models_client.SetIamPolicy(model_ref, policy, 'bindings,etag')
def RemoveIamPolicyBinding(models_client, model, member, role):
  """Removes a member/role binding from the model's IAM policy."""
  model_ref = ParseModel(model)
  policy = models_client.GetIamPolicy(model_ref)
  iam_util.RemoveBindingFromIamPolicy(policy, member, role)
  updated = models_client.SetIamPolicy(model_ref, policy, 'bindings,etag')
  iam_util.LogSetIamPolicy(model_ref.Name(), 'model')
  return updated
def AddIamPolicyBindingWithCondition(models_client, model, member, role,
                                     condition):
  """Adds IAM binding with condition to ml engine model's IAM policy."""
  model_ref = ParseModel(model)
  policy = models_client.GetIamPolicy(model_ref)
  messages = models_client.messages
  iam_util.AddBindingToIamPolicyWithCondition(
      messages.GoogleIamV1Binding,
      messages.GoogleTypeExpr,
      policy,
      member,
      role,
      condition)
  return models_client.SetIamPolicy(model_ref, policy, 'bindings,etag')
def RemoveIamPolicyBindingWithCondition(models_client, model, member, role,
                                        condition):
  """Removes a conditional member/role binding from the model's IAM policy."""
  model_ref = ParseModel(model)
  policy = models_client.GetIamPolicy(model_ref)
  iam_util.RemoveBindingFromIamPolicyWithCondition(
      policy, member, role, condition)
  updated = models_client.SetIamPolicy(model_ref, policy, 'bindings,etag')
  iam_util.LogSetIamPolicy(model_ref.Name(), 'model')
  return updated

View File

@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for ml-engine operations commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
def Cancel(operations_client, operation):
  """Cancels the named operation in the current project."""
  ref = resources.REGISTRY.Parse(
      operation,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='ml.projects.operations')
  return operations_client.Cancel(ref)
def Describe(operations_client, operation):
  """Fetches the named operation in the current project."""
  ref = resources.REGISTRY.Parse(
      operation,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='ml.projects.operations')
  return operations_client.Get(ref)
def List(operations_client):
  """Lists operations in the current project."""
  project = properties.VALUES.core.project.GetOrFail()
  project_ref = resources.REGISTRY.Parse(project, collection='ml.projects')
  return operations_client.List(project_ref)
def Wait(operations_client, operation):
  """Blocks until the named operation completes and returns the result."""
  ref = resources.REGISTRY.Parse(
      operation,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='ml.projects.operations')
  return operations_client.WaitForOperation(operations_client.Get(ref))

View File

@@ -0,0 +1,228 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for reading instances for prediction."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import io
import json
from googlecloudsdk.api_lib.ml_engine import models
from googlecloudsdk.api_lib.ml_engine import versions_api
from googlecloudsdk.core import exceptions as core_exceptions
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core.console import console_io
from googlecloudsdk.core.util import encoding
import six
class InvalidInstancesFileError(core_exceptions.Error):
  """Raised when the instances input file is invalid in some way."""
def ReadRequest(input_file):
  """Reads a JSON request from the specified input file.

  Args:
    input_file: An open file-like object for the input file.

  Returns:
    A list of instances.

  Raises:
    InvalidInstancesFileError: If the input file is invalid.
  """
  # `json.loads` doesn't always work with binary / UTF-8 data in
  # Python 3.5, so read the whole file and decode it first.
  contents = input_file.read()
  if isinstance(contents, six.binary_type):
    # Decode with utf-8-sig so a leading UTF8-BOM is stripped.
    contents = encoding.Decode(contents, encoding='utf-8-sig')
  try:
    request = json.loads(contents)
  except ValueError:
    # Consistency fix: refer to the current `ai-platform` command surface,
    # matching the message in ReadInstances (was the deprecated `ml-engine`).
    raise InvalidInstancesFileError(
        'Input instances are not in JSON format. '
        'See "gcloud ai-platform predict --help" for details.')
  if 'instances' not in request:
    raise InvalidInstancesFileError(
        'Invalid JSON request: missing "instances" attribute')
  instances = request['instances']
  if not isinstance(instances, list):
    raise InvalidInstancesFileError(
        'Invalid JSON request: "instances" must be a list')
  return instances
def ReadInstances(input_file, data_format, limit=None):
  """Reads the instances from input file.

  Args:
    input_file: An open file-like object for the input file.
    data_format: str, data format of the input file, 'json' or 'text'.
    limit: int, the maximum number of instances allowed in the file.

  Returns:
    A list of instances.

  Raises:
    InvalidInstancesFileError: If the input file is invalid (invalid format
      or contains too many/zero instances).
  """
  instances = []
  for index, raw_line in enumerate(input_file):
    if isinstance(raw_line, six.binary_type):
      # Decode with utf-8-sig so a leading UTF8-BOM is stripped.
      raw_line = encoding.Decode(raw_line, encoding='utf-8-sig')
    content = raw_line.rstrip('\r\n')
    if not content:
      raise InvalidInstancesFileError('Empty line is not allowed in the '
                                      'instances file.')
    if limit and index >= limit:
      raise InvalidInstancesFileError(
          'The gcloud CLI can currently process no more than ' +
          six.text_type(limit) +
          ' instances per file. Please use the API directly if you need to send'
          ' more.')
    if data_format == 'json':
      try:
        instances.append(json.loads(content))
      except ValueError:
        raise InvalidInstancesFileError(
            'Input instances are not in JSON format. '
            'See "gcloud ai-platform predict --help" for details.')
    elif data_format == 'text':
      instances.append(content)
  if not instances:
    raise InvalidInstancesFileError(
        'No valid instance was found in input file.')
  return instances
def ReadInstancesFromArgs(json_request,
                          json_instances,
                          text_instances,
                          limit=None):
  """Reads the instances from the given file path ('-' for stdin).

  Exactly one of json_request, json_instances, text_instances must be given.

  Args:
    json_request: str or None, a path to a file ('-' for stdin) containing
      the JSON body of a prediction request.
    json_instances: str or None, a path to a file ('-' for stdin) containing
      instances in JSON format.
    text_instances: str or None, a path to a file ('-' for stdin) containing
      instances in text format.
    limit: int, the maximum number of instances allowed in the file.

  Returns:
    A list of instances.

  Raises:
    InvalidInstancesFileError: If the input file is invalid (invalid format or
      contains too many/zero instances), or an improper combination of input
      files was given.
  """
  # Bug fix: the previous check built a *set* of the truthy values, so two
  # flags pointing at the same path (e.g. both '-') deduplicated to a single
  # element and slipped past the mutual-exclusion check. Count the flags that
  # were actually provided instead.
  num_specified = sum(
      1 for arg in (json_request, json_instances, text_instances) if arg)
  if num_specified != 1:
    raise InvalidInstancesFileError(
        'Exactly one of --json-request, --json-instances and --text-instances '
        'must be specified.')
  if json_request:
    data_format = 'json_request'
    input_file = json_request
  elif json_instances:
    data_format = 'json'
    input_file = json_instances
  else:
    data_format = 'text'
    input_file = text_instances
  data = console_io.ReadFromFileOrStdin(input_file, binary=True)
  with io.BytesIO(data) as f:
    if data_format == 'json_request':
      return ReadRequest(f)
    return ReadInstances(f, data_format, limit=limit)
def ParseModelOrVersionRef(model_id, version_id):
  """Parses a version ref when version_id is given, otherwise a model ref."""
  project_param = {'projectsId': properties.VALUES.core.project.GetOrFail}
  if not version_id:
    return resources.REGISTRY.Parse(
        model_id, params=project_param, collection='ml.projects.models')
  version_params = dict(project_param, modelsId=model_id)
  return resources.REGISTRY.Parse(
      version_id,
      collection='ml.projects.models.versions',
      params=version_params)
def GetDefaultFormat(predictions):
  """Picks a gcloud output format suited to the shape of the predictions."""
  if not isinstance(predictions, list):
    # Non-list output usually indicates an error case, so surface the full
    # API response.
    return 'json'
  if not predictions:
    return None
  first = predictions[0]
  # predictions is guaranteed by API contract to be a list of similarly
  # shaped objects, so the first element determines the format.
  if not isinstance(first, dict):
    return 'table[no-heading](predictions)'
  keys = ', '.join(sorted(first.keys()))
  return """
    table(
        predictions:format="table(
            {}
        )"
    )""".format(keys)
def GetRuntimeVersion(model=None, version=None):
  """Returns (framework, runtimeVersion) for a version or a model's default."""
  if version:
    version_data = versions_api.VersionsClient().Get(
        ParseModelOrVersionRef(model, version))
  else:
    version_data = models.ModelsClient().Get(model).defaultVersion
  return version_data.framework, version_data.runtimeVersion
def CheckRuntimeVersion(model=None, version=None):
  """Check if runtime-version is more than 1.8."""
  framework, runtime_version = GetRuntimeVersion(model, version)
  if framework != 'TENSORFLOW':
    return False
  major, minor = (int(part) for part in runtime_version.split('.'))
  # (major, minor) >= (1, 8) is equivalent to the original
  # (major == 1 and minor >= 8) or major > 1.
  return (major, minor) >= (1, 8)

View File

@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for handling region flag."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.command_lib.ml_engine import constants
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core.console import console_io
def _PromptForRegion():
  """Prompt for region from list of available regions.

  Returns:
    The region chosen by the user (str), or None when prompting is disabled.
  """
  if not console_io.CanPrompt():
    return None
  choices = constants.SUPPORTED_REGIONS_WITH_GLOBAL
  selected_idx = console_io.PromptChoice(
      choices,
      message=('Please specify a region:\n'
               '(For the global endpoint the region needs to be '
               'specified as \'global\'.)\n'),
      cancel_option=True)
  region = choices[selected_idx]
  log.status.Print(
      'To make this the default region, run '
      '`gcloud config set ai_platform/region {}`.\n'.format(region))
  return region
def GetRegion(args):
  """Gets the region and prompt for region if not provided.

  Note: region can be either `global` or one of supported regions.
  Region is decided in the following order:
  - region argument;
  - ai_platform/region gcloud config;
  - prompt user input.

  Args:
    args: Namespace, The args namespace.

  Returns:
    A str representing region.
  """
  if args.IsSpecified('region'):
    return args.region
  region_prop = properties.VALUES.ai_platform.region
  if region_prop.IsExplicitlySet():
    return region_prop.Get()
  # Prompting is disallowed in unit tests / non-interactive runs, so fall
  # back to us-central1 when no choice could be made.
  return _PromptForRegion() or 'us-central1'

View File

@@ -0,0 +1,33 @@
# Resource argument definitions for Google Cloud ML Engine (AI Platform)
# gcloud commands. Each entry maps an argument name to its API collection
# and URI parameters.
project:
  name: project
  collection: ml.projects
  attributes:
  - parameter_name: projectsId
    attribute_name: project
    help: The name of the Google Cloud ML Engine project.

location:
  name: location
  collection: ml.projects.locations
  attributes:
  - parameter_name: locationsId
    attribute_name: location
    help: The name of the Google Cloud ML Engine location.

model:
  name: model
  collection: ml.projects.models
  attributes:
  # Anchor reused by `version` below so both resources share the same
  # model attribute definition.
  - &model
    parameter_name: modelsId
    attribute_name: model
    help: The name of the Google Cloud ML Engine model.

version:
  name: version
  collection: ml.projects.models.versions
  attributes:
  - *model
  - parameter_name: versionsId
    attribute_name: version
    help: The name of the Google Cloud ML Engine model version.

View File

@@ -0,0 +1,162 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common ML file upload logic."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import hashlib
import os
from googlecloudsdk.api_lib.storage import storage_api
from googlecloudsdk.api_lib.storage import storage_util
from googlecloudsdk.core import exceptions
from googlecloudsdk.core.util import files as file_utils
import six
from six.moves import zip
# Local alias for the OS path separator, for ease of mocking in tests
# without messing up core Python functionality.
_PATH_SEP = os.path.sep
class MissingStagingBucketException(Exception):
  """Raised when a local path is given without a staging bucket.

  Deliberately not derived from core.exceptions.Error: callers are expected
  to catch it and re-raise with an actionable, command-specific message.
  """
class BadDirectoryError(exceptions.Error):
  """Raised when a directory provided for upload is missing or empty."""
def UploadFiles(upload_pairs, bucket_ref, gs_prefix=None):
  """Uploads local files to a checksum-prefixed location in Cloud Storage.

  The destination prefix is `<gs_prefix>/<sha256 of all file contents>` (or
  just the checksum when gs_prefix is None), so identical content maps to a
  stable path.

  Args:
    upload_pairs: [(str, str)]. Pairs of absolute local paths and the
      corresponding paths in Cloud Storage after the prefix. For example,
      ('/path/foo', 'bar') uploads '/path/foo' to '<prefix>/bar'.
    bucket_ref: storage_util.BucketReference. Files are uploaded here.
    gs_prefix: str. Prefix to the GCS path where files will be uploaded.

  Returns:
    [str]. Fully qualified gs:// paths for the uploaded files, in the same
    order they were provided.
  """
  digest = file_utils.Checksum(algorithm=hashlib.sha256)
  for local_path, _ in upload_pairs:
    digest.AddFileContents(local_path)
  hex_digest = digest.HexDigest()
  if gs_prefix is None:
    gs_prefix = hex_digest
  else:
    gs_prefix = '/'.join([gs_prefix, hex_digest])
  client = storage_api.StorageClient()
  uploaded = []
  for local_path, dest_path in upload_pairs:
    obj_ref = storage_util.ObjectReference.FromBucketRef(
        bucket_ref, '/'.join([gs_prefix, dest_path]))
    obj = client.CopyFileToGCS(local_path, obj_ref)
    # 'gs:/' joined with bucket/name yields the canonical gs:// URL.
    uploaded.append('/'.join(['gs:/', obj.bucket, obj.name]))
  return uploaded
def _GetFilesRelative(root):
  """Return all the descendents of root, relative to its path.

  For instance, given the following directory structure
    /path/to/root/a
    /path/to/root/a/b
    /path/to/root/c
  this function would return `['a', 'a/b', 'c']`.

  Args:
    root: str, the path to list descendents of.

  Returns:
    list of str, the paths in the given directory.
  """
  relative_paths = []
  # six.text_type keeps the walk in unicode on Python 2.
  for dirpath, _, filenames in os.walk(six.text_type(root)):
    relative_paths.extend(
        os.path.relpath(os.path.join(dirpath, name), start=root)
        for name in filenames)
  return relative_paths
def UploadDirectoryIfNecessary(path, staging_bucket=None, gs_prefix=None):
  """Uploads path to Cloud Storage if it isn't already there.

  Translates local file system paths to Cloud Storage-style paths (i.e. using
  the Unix path separator '/').

  Args:
    path: str, the path to the directory. Can be a Cloud Storage ("gs://")
      path or a local filesystem path (no protocol).
    staging_bucket: storage_util.BucketReference or None. If the path is
      local, the bucket to which it should be uploaded.
    gs_prefix: str, prefix for the directory within the staging bucket.

  Returns:
    str, a Cloud Storage path where the directory has been uploaded (possibly
    prior to the execution of this function).

  Raises:
    MissingStagingBucketException: if `path` is a local path, but
      staging_bucket isn't found.
    BadDirectoryError: if the given directory couldn't be found or is empty.
  """
  if path.startswith('gs://'):
    # Already in Cloud Storage; nothing to upload.
    return path
  if staging_bucket is None:
    raise MissingStagingBucketException(
        'The path provided was local, but no staging bucket for upload '
        'was provided.')
  if not os.path.isdir(path):
    raise BadDirectoryError('[{}] is not a valid directory.'.format(path))
  relative_files = _GetFilesRelative(path)
  # Cloud Storage uses '/' as its virtual separator regardless of OS.
  dests = [rel.replace(_PATH_SEP, '/') for rel in relative_files]
  # Re-prefix with `path` so UploadFiles can locate the files on disk.
  local_paths = [_PATH_SEP.join([path, rel]) for rel in relative_files]
  uploaded_paths = UploadFiles(
      list(zip(local_paths, dests)), staging_bucket, gs_prefix=gs_prefix)
  if not uploaded_paths:
    raise BadDirectoryError(
        'Cannot upload contents of directory [{}] to Google Cloud Storage; '
        'directory has no files.'.format(path))
  # Strip the first file's relative suffix to recover the common uploaded
  # prefix (depends on UploadFiles preserving input order).
  return uploaded_paths[0][:-len(dests[0])]

View File

@@ -0,0 +1,274 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for ml versions commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.ml_engine import versions_api
from googlecloudsdk.command_lib.ml_engine import models_util
from googlecloudsdk.command_lib.ml_engine import uploads
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.command_lib.util.args import repeated
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core.console import console_io
class InvalidArgumentCombinationError(exceptions.Error):
  """Raised when a given combination of arguments is invalid."""
def ParseCreateLabels(client, args):
  """Builds a LabelsValue message from the create-command label args."""
  labels_cls = client.version_class.LabelsValue
  return labels_util.ParseCreateArgs(args, labels_cls)
def ParseUpdateLabels(client, get_result, args):
  """Builds a labels diff for update; current labels come from get_result."""
  labels_cls = client.version_class.LabelsValue
  return labels_util.ProcessUpdateArgsLazy(
      args, labels_cls, get_result.GetAttrThunk('labels'))
def ParseVersion(model, version):
  """Parses a model/version ID into a version resource object."""
  params = {
      'projectsId': properties.VALUES.core.project.GetOrFail,
      'modelsId': model,
  }
  return resources.REGISTRY.Parse(
      version, params=params, collection='ml.projects.models.versions')
def WaitForOpMaybe(operations_client, op, asyncronous=False, message=None):
  """Waits for an operation unless the asyncronous flag is set.

  Args:
    operations_client: api_lib.ml_engine.operations.OperationsClient, the
      client via which to poll.
    op: Cloud ML Engine operation, the operation to poll.
    asyncronous: bool, whether to return immediately instead of waiting.
    message: str, the message to display while waiting for the operation.

  Returns:
    `op` unchanged when asyncronous is set; otherwise the completed
    operation's response.
  """
  if asyncronous:
    return op
  completed = operations_client.WaitForOperation(op, message=message)
  return completed.response
def Create(versions_client,
           operations_client,
           version_id,
           model=None,
           origin=None,
           staging_bucket=None,
           runtime_version=None,
           config_file=None,
           asyncronous=None,
           labels=None,
           machine_type=None,
           description=None,
           framework=None,
           python_version=None,
           prediction_class=None,
           package_uris=None,
           accelerator_config=None,
           service_account=None,
           explanation_method=None,
           num_integral_steps=None,
           num_paths=None,
           image=None,
           command=None,
           container_args=None,
           env_vars=None,
           ports=None,
           predict_route=None,
           health_route=None,
           min_nodes=None,
           max_nodes=None,
           metrics=None,
           containers_hidden=True,
           autoscaling_hidden=True):
  """Create a version, optionally waiting for creation to finish.

  Args:
    versions_client: client used to build and create the version.
    operations_client: client used to poll the create operation.
    version_id: str, the ID of the version to create.
    model: str, the model under which the version is created.
    origin: str, model artifact location; a local path is first staged to
      Cloud Storage via staging_bucket.
    staging_bucket: bucket used to stage a local `origin` directory.
    runtime_version: version field forwarded to BuildVersion.
    config_file: path to a version config file forwarded to BuildVersion.
    asyncronous: bool, return the operation instead of waiting for it.
    labels: labels message forwarded to BuildVersion.
    machine_type: version field forwarded to BuildVersion.
    description: version field forwarded to BuildVersion.
    framework: version field forwarded to BuildVersion.
    python_version: version field forwarded to BuildVersion.
    prediction_class: version field forwarded to BuildVersion.
    package_uris: version field forwarded to BuildVersion.
    accelerator_config: version field forwarded to BuildVersion.
    service_account: version field forwarded to BuildVersion.
    explanation_method: explainability setting; when set, a disclaimer about
      explanation limitations is printed.
    num_integral_steps: explainability setting forwarded to BuildVersion.
    num_paths: explainability setting forwarded to BuildVersion.
    image: custom-container image forwarded to BuildVersion.
    command: custom-container setting forwarded to BuildVersion.
    container_args: custom-container setting forwarded to BuildVersion.
    env_vars: custom-container setting forwarded to BuildVersion.
    ports: custom-container setting forwarded to BuildVersion.
    predict_route: custom-container setting forwarded to BuildVersion.
    health_route: custom-container setting forwarded to BuildVersion.
    min_nodes: autoscaling setting forwarded to BuildVersion.
    max_nodes: autoscaling setting forwarded to BuildVersion.
    metrics: autoscaling setting forwarded to BuildVersion.
    containers_hidden: bool; when True, a deploymentUri (from --origin or
      the config file) is required; when False, an image is also accepted.
    autoscaling_hidden: bool, forwarded to BuildVersion.

  Returns:
    The create operation (when asyncronous) or the created version.

  Raises:
    InvalidArgumentCombinationError: for invalid flag combinations.
  """
  if origin:
    try:
      # A local --origin directory must first be staged in Cloud Storage.
      origin = uploads.UploadDirectoryIfNecessary(origin, staging_bucket)
    except uploads.MissingStagingBucketException:
      raise InvalidArgumentCombinationError(
          'If --origin is provided as a local path, --staging-bucket must be '
          'given as well.')
  if explanation_method is not None:
    # Disclaimer shown whenever explainability is requested.
    log.status.Print(
        'Explanations reflect patterns in your model, but don\'t necessarily '
        'reveal fundamental relationships about your data population. See '
        'https://cloud.google.com/vertex-ai/docs/explainable-ai/limitations '
        'for more information.')
  model_ref = models_util.ParseModel(model)
  version = versions_client.BuildVersion(
      version_id,
      path=config_file,
      deployment_uri=origin,
      runtime_version=runtime_version,
      labels=labels,
      description=description,
      machine_type=machine_type,
      framework=framework,
      python_version=python_version,
      package_uris=package_uris,
      prediction_class=prediction_class,
      accelerator_config=accelerator_config,
      service_account=service_account,
      explanation_method=explanation_method,
      num_integral_steps=num_integral_steps,
      num_paths=num_paths,
      image=image,
      command=command,
      container_args=container_args,
      env_vars=env_vars,
      ports=ports,
      predict_route=predict_route,
      health_route=health_route,
      min_nodes=min_nodes,
      max_nodes=max_nodes,
      metrics=metrics,
      containers_hidden=containers_hidden,
      autoscaling_hidden=autoscaling_hidden)
  # When containers are hidden, a deployment URI is the only valid artifact
  # source; validate it was supplied via --origin or the config file.
  if not version.deploymentUri and containers_hidden:
    raise InvalidArgumentCombinationError(
        'Either `--origin` must be provided or `deploymentUri` must be '
        'provided in the file given by `--config`.')
  has_image = (
      hasattr(version, 'container') and hasattr(version.container, 'image') and
      version.container.image)
  # When containers are allowed, either an artifact URI or an image works.
  if not version.deploymentUri and not has_image and not containers_hidden:
    raise InvalidArgumentCombinationError(
        'Either `--origin`, `--image`, or equivalent parameters in a config '
        'file (from `--config`) must be specified.')
  op = versions_client.Create(model_ref, version)
  return WaitForOpMaybe(
      operations_client, op, asyncronous=asyncronous,
      message='Creating version (this might take a few minutes)...')
def Delete(versions_client, operations_client, version, model=None):
  """Deletes a version after user confirmation, waiting for the operation."""
  version_ref = ParseVersion(model, version)
  console_io.PromptContinue(
      'This will delete version [{}]...'.format(version_ref.versionsId),
      cancel_on_no=True)
  op = versions_client.Delete(version_ref)
  # Deletion always waits for completion (asyncronous=False).
  return WaitForOpMaybe(
      operations_client, op, asyncronous=False,
      message='Deleting version [{}]...'.format(version_ref.versionsId))
def Describe(versions_client, version, model=None):
  """Fetches the given version of the model."""
  return versions_client.Get(ParseVersion(model, version))
def List(versions_client, model=None):
  """Lists the versions of the given model."""
  return versions_client.List(models_util.ParseModel(model))
# Version fields that a --config YAML file may set during `Update`; this set
# is passed to versions_client.ReadConfig as the allowed field list.
_ALLOWED_UPDATE_YAML_FIELDS = frozenset([
    'autoScaling',
    'description',
    'manualScaling',
    'requestLoggingConfig',
])
def Update(versions_client, operations_client, version_ref, args):
  """Update the given version.

  Flag values take precedence; fields missing from the flags fall back to
  values from the --config YAML file (restricted to
  _ALLOWED_UPDATE_YAML_FIELDS) when one is given.

  Args:
    versions_client: client used to read config and patch the version.
    operations_client: client used to poll the patch operation.
    version_ref: resource reference of the version to update.
    args: Namespace, the parsed command-line args.

  Returns:
    The completed operation's response, or None when there was nothing to
    update.
  """
  get_result = repeated.CachedResult.FromFunc(
      versions_client.Get, version_ref)
  version = None
  if hasattr(args, 'config') and args.config:
    version = versions_client.ReadConfig(
        args.config, _ALLOWED_UPDATE_YAML_FIELDS)
  # Flag value wins; otherwise fall back to the config file's description.
  description = args.description or (version.description if version else None)
  # The semantics of updating/removing/clearing labels from the config file is
  # not totally clear, so labels aren't currently allowed in config files.
  labels_update = ParseUpdateLabels(versions_client, get_result, args)
  manual_scaling_nodes = None
  if version and hasattr(version.manualScaling, 'nodes'):
    manual_scaling_nodes = version.manualScaling.nodes
  auto_scaling_min_nodes = None
  if version and hasattr(version.autoScaling, 'minNodes'):
    auto_scaling_min_nodes = version.autoScaling.minNodes
  auto_scaling_max_nodes = None
  if version and hasattr(version.autoScaling, 'maxNodes'):
    auto_scaling_max_nodes = version.autoScaling.maxNodes
  bigquery_table_name = getattr(args, 'bigquery_table_name', None)
  if bigquery_table_name is None and version and hasattr(
      version.requestLoggingConfig, 'bigqueryTableName'):
    bigquery_table_name = version.requestLoggingConfig.bigqueryTableName
  sampling_percentage = getattr(args, 'sampling_percentage', None)
  if sampling_percentage is None and version and hasattr(
      version.requestLoggingConfig, 'samplingPercentage'):
    sampling_percentage = version.requestLoggingConfig.samplingPercentage
  all_args = ['update_labels', 'clear_labels', 'remove_labels', 'description']
  try:
    op = versions_client.Patch(
        version_ref,
        labels_update,
        description,
        manual_scaling_nodes=manual_scaling_nodes,
        auto_scaling_min_nodes=auto_scaling_min_nodes,
        auto_scaling_max_nodes=auto_scaling_max_nodes,
        bigquery_table_name=bigquery_table_name,
        sampling_percentage=sampling_percentage)
  except versions_api.NoFieldsSpecifiedError:
    # Only swallow the error when the user genuinely passed no update flags.
    if not any(args.IsSpecified(arg) for arg in all_args):
      raise
    log.status.Print('No update to perform.')
    return None
  else:
    return operations_client.WaitForOperation(
        op, message='Updating version [{}]'.format(version_ref.Name())).response
def SetDefault(versions_client, version, model=None):
  """Makes the given version the model's default."""
  return versions_client.SetDefault(ParseVersion(model, version))
def ValidateFrameworkAndMachineTypeGa(framework, machine_type):
  """Raises unless the framework/machine-type combination is allowed in GA."""
  frameworks_enum = (
      versions_api.GetMessagesModule().GoogleCloudMlV1Version
      .FrameworkValueValuesEnum)
  is_tensorflow = framework == frameworks_enum.TENSORFLOW
  if not is_tensorflow and not machine_type.startswith('ml'):
    raise InvalidArgumentCombinationError(
        'Machine type {0} is currently only supported with tensorflow.'.format(
            machine_type))