feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,329 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with ML jobs API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import encoding
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.util.apis import arg_utils
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml
class NoFieldsSpecifiedError(exceptions.Error):
"""Error indicating that no updates were requested in a Patch operation."""
class NoPackagesSpecifiedError(exceptions.Error):
"""Error that no packages were specified for non-custom training."""
def GetMessagesModule(version='v1'):
return apis.GetMessagesModule('ml', version)
def GetClientInstance(version='v1', no_http=False):
return apis.GetClientInstance('ml', version, no_http=no_http)
class JobsClient(object):
"""Client for jobs service in the Cloud ML Engine API."""
def __init__(self, client=None, messages=None,
short_message_prefix='GoogleCloudMlV1', client_version='v1'):
self.client = client or GetClientInstance(client_version)
self.messages = messages or self.client.MESSAGES_MODULE
self._short_message_prefix = short_message_prefix
def GetShortMessage(self, short_message_name):
return getattr(self.messages,
'{prefix}{name}'.format(prefix=self._short_message_prefix,
name=short_message_name), None)
@property
def state_enum(self):
return self.messages.GoogleCloudMlV1Job.StateValueValuesEnum
def List(self, project_ref):
req = self.messages.MlProjectsJobsListRequest(
parent=project_ref.RelativeName())
return list_pager.YieldFromList(
self.client.projects_jobs, req, field='jobs',
batch_size_attribute='pageSize')
@property
def job_class(self):
return self.messages.GoogleCloudMlV1Job
@property
def training_input_class(self):
return self.messages.GoogleCloudMlV1TrainingInput
@property
def prediction_input_class(self):
return self.messages.GoogleCloudMlV1PredictionInput
def _MakeCreateRequest(self, parent=None, job=None):
return self.messages.MlProjectsJobsCreateRequest(
parent=parent,
googleCloudMlV1Job=job)
def Create(self, project_ref, job):
return self.client.projects_jobs.Create(
self._MakeCreateRequest(
parent=project_ref.RelativeName(),
job=job))
def Cancel(self, job_ref):
"""Cancels given job."""
req = self.messages.MlProjectsJobsCancelRequest(name=job_ref.RelativeName())
return self.client.projects_jobs.Cancel(req)
def Get(self, job_ref):
req = self.messages.MlProjectsJobsGetRequest(name=job_ref.RelativeName())
return self.client.projects_jobs.Get(req)
def Patch(self, job_ref, labels_update):
"""Update a job."""
job = self.job_class()
update_mask = []
if labels_update.needs_update:
job.labels = labels_update.labels
update_mask.append('labels')
if not update_mask:
raise NoFieldsSpecifiedError('No updates requested.')
req = self.messages.MlProjectsJobsPatchRequest(
name=job_ref.RelativeName(),
googleCloudMlV1Job=job,
updateMask=','.join(update_mask)
)
return self.client.projects_jobs.Patch(req)
def BuildTrainingJob(self,
path=None,
module_name=None,
job_name=None,
trainer_uri=None,
region=None,
job_dir=None,
scale_tier=None,
user_args=None,
runtime_version=None,
python_version=None,
network=None,
service_account=None,
labels=None,
kms_key=None,
custom_train_server_config=None,
enable_web_access=None):
"""Builds a Cloud ML Engine Job from a config file and/or flag values.
Args:
path: path to a yaml configuration file
module_name: value to set for moduleName field (overrides yaml file)
job_name: value to set for jobName field (overrides yaml file)
trainer_uri: List of values to set for trainerUri field (overrides yaml
file)
region: compute region in which to run the job (overrides yaml file)
job_dir: Cloud Storage working directory for the job (overrides yaml
file)
scale_tier: ScaleTierValueValuesEnum the scale tier for the job
(overrides yaml file)
user_args: [str]. A list of arguments to pass through to the job.
(overrides yaml file)
runtime_version: the runtime version in which to run the job (overrides
yaml file)
python_version: the Python version in which to run the job (overrides
yaml file)
network: user network to which the job should be peered with (overrides
yaml file)
service_account: A service account (email address string) to use for the
job.
labels: Job.LabelsValue, the Cloud labels for the job
kms_key: A customer-managed encryption key to use for the job.
custom_train_server_config: jobs_util.CustomTrainingInputServerConfig,
configuration object for custom server parameters.
enable_web_access: whether to enable the interactive shell for the job.
Raises:
NoPackagesSpecifiedError: if a non-custom job was specified without any
trainer_uris.
Returns:
A constructed Job object.
"""
job = self.job_class()
# TODO(b/123467089): Remove yaml file loading here, only parse data objects
if path:
data = yaml.load_path(path)
if data:
job = encoding.DictToMessage(data, self.job_class)
if job_name:
job.jobId = job_name
if labels is not None:
job.labels = labels
if not job.trainingInput:
job.trainingInput = self.training_input_class()
additional_fields = {
'pythonModule': module_name,
'args': user_args,
'packageUris': trainer_uri,
'region': region,
'jobDir': job_dir,
'scaleTier': scale_tier,
'runtimeVersion': runtime_version,
'pythonVersion': python_version,
'network': network,
'serviceAccount': service_account,
'enableWebAccess': enable_web_access,
}
for field_name, value in additional_fields.items():
if value is not None:
setattr(job.trainingInput, field_name, value)
if kms_key:
arg_utils.SetFieldInMessage(job,
'trainingInput.encryptionConfig.kmsKeyName',
kms_key)
if custom_train_server_config:
for field_name, value in custom_train_server_config.GetFieldMap().items():
if value is not None:
if (field_name.endswith('Config') and
not field_name.endswith('TfConfig')):
if value['imageUri']:
arg_utils.SetFieldInMessage(
job,
'trainingInput.{}.imageUri'.format(field_name),
value['imageUri'])
if value['acceleratorConfig']['type']:
arg_utils.SetFieldInMessage(
job,
'trainingInput.{}.acceleratorConfig.type'.format(field_name),
value['acceleratorConfig']['type'])
if value['acceleratorConfig']['count']:
arg_utils.SetFieldInMessage(
job,
'trainingInput.{}.acceleratorConfig.count'.format(field_name),
value['acceleratorConfig']['count'])
if field_name == 'workerConfig' and value['tpuTfVersion']:
arg_utils.SetFieldInMessage(
job,
'trainingInput.{}.tpuTfVersion'.format(field_name),
value['tpuTfVersion'])
else:
setattr(job.trainingInput, field_name, value)
if not self.HasPackageURIs(job) and not self.IsCustomContainerTraining(job):
raise NoPackagesSpecifiedError('Non-custom jobs must have packages.')
return job
def HasPackageURIs(self, job):
return bool(job.trainingInput.packageUris)
def IsCustomContainerTraining(self, job):
return bool(job.trainingInput.masterConfig and
job.trainingInput.masterConfig.imageUri)
def BuildBatchPredictionJob(self,
job_name=None,
model_dir=None,
model_name=None,
version_name=None,
input_paths=None,
data_format=None,
output_path=None,
region=None,
runtime_version=None,
max_worker_count=None,
batch_size=None,
signature_name=None,
labels=None,
accelerator_count=None,
accelerator_type=None):
"""Builds a Cloud ML Engine Job for batch prediction from flag values.
Args:
job_name: value to set for jobName field
model_dir: str, Google Cloud Storage location of the model files
model_name: str, value to set for modelName field
version_name: str, value to set for versionName field
input_paths: list of input files
data_format: format of the input files
output_path: single value for the output location
region: compute region in which to run the job
runtime_version: the runtime version in which to run the job
max_worker_count: int, the maximum number of workers to use
batch_size: int, the number of records per batch sent to Tensorflow
signature_name: str, name of input/output signature in the TF meta graph
labels: Job.LabelsValue, the Cloud labels for the job
accelerator_count: int, The number of accelerators to attach to the
machines
accelerator_type: AcceleratorsValueListEntryValuesEnum, The type of
accelerator to add to machine.
Returns:
A constructed Job object.
"""
project_id = properties.VALUES.core.project.GetOrFail()
if accelerator_type:
accelerator_config_msg = self.GetShortMessage('AcceleratorConfig')
accelerator_config = accelerator_config_msg(count=accelerator_count,
type=accelerator_type)
else:
accelerator_config = None
prediction_input = self.prediction_input_class(
inputPaths=input_paths,
outputPath=output_path,
region=region,
runtimeVersion=runtime_version,
maxWorkerCount=max_worker_count,
batchSize=batch_size,
accelerator=accelerator_config
)
prediction_input.dataFormat = prediction_input.DataFormatValueValuesEnum(
data_format)
if model_dir:
prediction_input.uri = model_dir
elif version_name:
version_ref = resources.REGISTRY.Parse(
version_name, collection='ml.projects.models.versions',
params={'modelsId': model_name, 'projectsId': project_id})
prediction_input.versionName = version_ref.RelativeName()
else:
model_ref = resources.REGISTRY.Parse(
model_name, collection='ml.projects.models',
params={'projectsId': project_id})
prediction_input.modelName = model_ref.RelativeName()
if signature_name:
prediction_input.signatureName = signature_name
return self.job_class(
jobId=job_name,
predictionInput=prediction_input,
labels=labels
)

View File

@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with ML locations API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
class NoFieldsSpecifiedError(exceptions.Error):
"""Error indicating that no updates were requested in a Patch operation."""
def _ParseLocation(location):
return resources.REGISTRY.Parse(
location,
params={'projectsId': properties.VALUES.core.project.GetOrFail},
collection='ml.projects.locations')
class LocationsClient(object):
"""High-level client for the AI Platform locations surface."""
def __init__(self, client=None, messages=None):
self.client = client or apis.GetClientInstance('ml', 'v1')
self.messages = messages or self.client.MESSAGES_MODULE
def Get(self, location):
"""Get details about a location."""
location_ref = _ParseLocation(location)
req = self.messages.MlProjectsLocationsGetRequest(
name=location_ref.RelativeName())
return self.client.projects_locations.Get(req)
def List(self, project_ref):
"""List available locations for the project."""
req = self.messages.MlProjectsLocationsListRequest(
parent=project_ref.RelativeName())
return list_pager.YieldFromList(
self.client.projects_locations,
req,
field='locations',
batch_size_attribute='pageSize')

View File

@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with ML models API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
class NoFieldsSpecifiedError(exceptions.Error):
"""Error indicating that no updates were requested in a Patch operation."""
def _ParseModel(model_id):
return resources.REGISTRY.Parse(
model_id,
params={'projectsId': properties.VALUES.core.project.GetOrFail},
collection='ml.projects.models')
class ModelsClient(object):
"""High-level client for the ML models surface."""
def __init__(self, client=None, messages=None):
self.client = client or apis.GetClientInstance('ml', 'v1')
self.messages = messages or self.client.MESSAGES_MODULE
def Create(self, model_name, regions, enable_logging=False,
enable_console_logging=False, labels=None, description=None):
"""Create a new model."""
model_ref = _ParseModel(model_name)
regions_list = regions or []
project_ref = resources.REGISTRY.Parse(model_ref.projectsId,
collection='ml.projects')
req = self.messages.MlProjectsModelsCreateRequest(
parent=project_ref.RelativeName(),
googleCloudMlV1Model=self.messages.GoogleCloudMlV1Model(
name=model_ref.Name(),
regions=regions_list,
onlinePredictionLogging=enable_logging,
onlinePredictionConsoleLogging=enable_console_logging,
description=description,
labels=labels))
return self.client.projects_models.Create(req)
def GetIamPolicy(self, model_ref):
return self.client.projects_models.GetIamPolicy(
self.messages.MlProjectsModelsGetIamPolicyRequest(
resource=model_ref.RelativeName()))
def SetIamPolicy(self, model_ref, policy, update_mask):
request = self.messages.GoogleIamV1SetIamPolicyRequest(
policy=policy,
updateMask=update_mask)
return self.client.projects_models.SetIamPolicy(
self.messages.MlProjectsModelsSetIamPolicyRequest(
googleIamV1SetIamPolicyRequest=request,
resource=model_ref.RelativeName()))
def Delete(self, model):
"""Delete an existing model."""
model_ref = _ParseModel(model)
req = self.messages.MlProjectsModelsDeleteRequest(
name=model_ref.RelativeName())
return self.client.projects_models.Delete(req)
def Get(self, model):
"""Get details about a model."""
model_ref = _ParseModel(model)
req = self.messages.MlProjectsModelsGetRequest(
name=model_ref.RelativeName())
return self.client.projects_models.Get(req)
def List(self, project_ref):
"""List models in the project."""
req = self.messages.MlProjectsModelsListRequest(
parent=project_ref.RelativeName())
return list_pager.YieldFromList(
self.client.projects_models,
req,
field='models',
batch_size_attribute='pageSize')
def Patch(self, model_ref, labels_update, description=None):
"""Update a model."""
model = self.messages.GoogleCloudMlV1Model()
update_mask = []
if labels_update.needs_update:
model.labels = labels_update.labels
update_mask.append('labels')
if description:
model.description = description
update_mask.append('description')
if not update_mask:
raise NoFieldsSpecifiedError('No updates requested.')
req = self.messages.MlProjectsModelsPatchRequest(
name=model_ref.RelativeName(),
googleCloudMlV1Model=model,
updateMask=','.join(update_mask))
return self.client.projects_models.Patch(req)

View File

@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with long-running operations (simple uri)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
def GetMessagesModule(version='v1'):
return apis.GetMessagesModule('ml', version)
def GetClientInstance(version='v1', no_http=False):
return apis.GetClientInstance('ml', version, no_http=no_http)
class CloudMlOperationPoller(waiter.CloudOperationPoller):
"""Poller for Cloud ML Engine operations API.
This is necessary because the core operations library doesn't directly support
simple_uri.
"""
def __init__(self, client):
self.client = client
super(CloudMlOperationPoller, self).__init__(
self.client.client.projects_operations,
self.client.client.projects_operations)
def Poll(self, operation_ref):
return self.client.Get(operation_ref)
def GetResult(self, operation):
return operation
class OperationsClient(object):
"""Client for operations service in the Cloud ML Engine API."""
def __init__(self, version='v1'):
self.client = GetClientInstance(version)
self.messages = self.client.MESSAGES_MODULE
def List(self, project_ref):
return list_pager.YieldFromList(
self.client.projects_operations,
self.messages.MlProjectsOperationsListRequest(
name=project_ref.RelativeName()),
field='operations',
batch_size_attribute='pageSize')
def Get(self, operation_ref):
return self.client.projects_operations.Get(
self.messages.MlProjectsOperationsGetRequest(
name=operation_ref.RelativeName()))
def Cancel(self, operation_ref):
return self.client.projects_operations.Cancel(
self.messages.MlProjectsOperationsCancelRequest(
name=operation_ref.RelativeName()))
def WaitForOperation(self, operation, message=None):
"""Wait until the operation is complete or times out.
Args:
operation: The operation resource to wait on
message: str, the message to print while waiting.
Returns:
The operation resource when it has completed
Raises:
OperationTimeoutError: when the operation polling times out
OperationError: when the operation completed with an error
"""
poller = CloudMlOperationPoller(self)
if poller.IsDone(operation):
return operation
operation_ref = resources.REGISTRY.Parse(
operation.name,
params={'projectsId': properties.VALUES.core.project.GetOrFail},
collection='ml.projects.operations')
if message is None:
message = 'Waiting for operation [{}]'.format(operation_ref.Name())
return waiter.WaitFor(
poller, operation_ref, message,
pre_start_sleep_ms=0,
max_wait_ms=60*60*1000,
exponential_sleep_multiplier=None,
jitter_ms=None,
wait_ceiling_ms=None,
sleep_ms=5000)

View File

@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with ML predict API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import json
from googlecloudsdk.core import exceptions as core_exceptions
from googlecloudsdk.core.credentials import requests
from six.moves import http_client as httplib
class InstancesEncodeError(core_exceptions.Error):
"""Indicates that error occurs while decoding the instances in http body."""
pass
class HttpRequestFailError(core_exceptions.Error):
"""Indicates that the http request fails in some way."""
pass
def _GetPrediction(url, body, headers):
"""Make http request to get prediction results."""
response = requests.GetSession().request(
'POST', url, data=body, headers=headers)
return response.status_code, response.text
def Predict(model_or_version_ref, instances, signature_name=None):
"""Performs online prediction on the input data file.
Args:
model_or_version_ref: a Resource representing either a model or a version.
instances: a list of JSON or UTF-8 encoded instances to perform
prediction on.
signature_name: name of input/output signature in the TF meta graph.
Returns:
A json object that contains predictions.
Raises:
HttpRequestFailError: if error happens with http request, or parsing
the http response.
"""
url = model_or_version_ref.SelfLink() + ':predict'
# Construct the body for the predict request.
headers = {'Content-Type': 'application/json'}
content = {'instances': instances}
if signature_name:
content['signature_name'] = signature_name
try:
body = json.dumps(content, sort_keys=True)
except (UnicodeDecodeError, TypeError):
# Python 2: UnicodeDecode Error, Python 3: TypeError
raise InstancesEncodeError('Instances cannot be JSON encoded, probably '
'because the input is not utf-8 encoded.')
# Workaround since gcloud cannot handle HttpBody properly, see b/31403673
response_status, response_body = _GetPrediction(url, body, headers)
if int(response_status) != httplib.OK:
raise HttpRequestFailError('HTTP request failed. Response: ' +
response_body)
try:
return json.loads(response_body)
except ValueError:
raise HttpRequestFailError('No JSON object could be decoded from the '
'HTTP response body: ' + response_body)
def Explain(model_or_version_ref, instances):
"""Performs online explanation on the input data file.
Args:
model_or_version_ref: a Resource representing either a model or a version.
instances: a list of JSON or UTF-8 encoded instances to perform
prediction on.
Returns:
A json object that contains explanations.
Raises:
HttpRequestFailError: if error happens with http request, or parsing
the http response.
"""
url = model_or_version_ref.SelfLink() + ':explain'
# Construct the body for the explain request.
headers = {'Content-Type': 'application/json'}
content = {'instances': instances}
try:
body = json.dumps(content, sort_keys=True)
except (UnicodeDecodeError, TypeError):
# Python 2: UnicodeDecode Error, Python 3: TypeError
raise InstancesEncodeError('Instances cannot be JSON encoded, probably '
'because the input is not utf-8 encoded.')
# Workaround since gcloud cannot handle HttpBody properly, see b/31403673
response_status, response_body = _GetPrediction(url, body, headers)
if int(response_status) != httplib.OK:
raise HttpRequestFailError('HTTP request failed. Response: ' +
response_body)
try:
return json.loads(response_body)
except ValueError:
raise HttpRequestFailError('No JSON object could be decoded from the '
'HTTP response body: ' + response_body)

View File

@@ -0,0 +1,435 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with ML versions API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import encoding
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import yaml
from googlecloudsdk.core.util import text
import six
class InvalidVersionConfigFile(exceptions.Error):
"""Error indicating an invalid Version configuration file."""
class NoFieldsSpecifiedError(exceptions.Error):
"""Error indicating an invalid Version configuration file."""
def GetMessagesModule(version='v1'):
return apis.GetMessagesModule('ml', version)
def GetClientInstance(version='v1', no_http=False):
return apis.GetClientInstance('ml', version, no_http=no_http)
class VersionsClient(object):
"""Client for the versions service of Cloud ML Engine."""
_ALLOWED_YAML_FIELDS = frozenset([
'autoScaling',
'deploymentUri',
'description',
'framework',
'labels',
'machineType',
'manualScaling',
'packageUris',
'predictionClass',
'pythonVersion',
'runtimeVersion',
'serviceAccount',
])
_CONTAINER_FIELDS = frozenset([
'container',
'routes',
])
def __init__(self, client=None, messages=None):
self.client = client or GetClientInstance()
self.messages = messages or self.client.MESSAGES_MODULE
@property
def version_class(self):
return self.messages.GoogleCloudMlV1Version
def _MakeCreateRequest(self, parent, version):
return self.messages.MlProjectsModelsVersionsCreateRequest(
parent=parent,
googleCloudMlV1Version=version)
def _MakeSetDefaultRequest(self, name):
request = self.messages.GoogleCloudMlV1SetDefaultVersionRequest()
return self.messages.MlProjectsModelsVersionsSetDefaultRequest(
name=name,
googleCloudMlV1SetDefaultVersionRequest=request)
def Create(self, model_ref, version):
"""Creates a new version in an existing model."""
return self.client.projects_models_versions.Create(
self._MakeCreateRequest(
parent=model_ref.RelativeName(),
version=version))
def Patch(self, version_ref, labels_update, description=None,
prediction_class_update=None, package_uris=None,
manual_scaling_nodes=None, auto_scaling_min_nodes=None,
auto_scaling_max_nodes=None, sampling_percentage=None,
bigquery_table_name=None):
"""Update a version."""
version = self.messages.GoogleCloudMlV1Version()
update_mask = []
if labels_update.needs_update:
version.labels = labels_update.labels
update_mask.append('labels')
if description:
version.description = description
update_mask.append('description')
if prediction_class_update is not None and prediction_class_update.needs_update:
update_mask.append('predictionClass')
version.predictionClass = prediction_class_update.value
if package_uris is not None:
update_mask.append('packageUris')
version.packageUris = package_uris
if manual_scaling_nodes is not None:
update_mask.append('manualScaling.nodes')
version.manualScaling = self.messages.GoogleCloudMlV1ManualScaling(
nodes=manual_scaling_nodes)
if auto_scaling_min_nodes is not None:
update_mask.append('autoScaling.minNodes')
version.autoScaling = self.messages.GoogleCloudMlV1AutoScaling(
minNodes=auto_scaling_min_nodes)
if auto_scaling_max_nodes is not None:
update_mask.append('autoScaling.maxNodes')
if version.autoScaling is not None:
version.autoScaling.maxNodes = auto_scaling_max_nodes
else:
version.autoScaling = self.messages.GoogleCloudMlV1AutoScaling(
maxNodes=auto_scaling_max_nodes)
if bigquery_table_name is not None:
update_mask.append('requestLoggingConfig')
version.requestLoggingConfig = self.messages.GoogleCloudMlV1RequestLoggingConfig(
bigqueryTableName=bigquery_table_name)
if sampling_percentage is not None:
if 'requestLoggingConfig' not in update_mask:
update_mask.append('requestLoggingConfig')
version.requestLoggingConfig = self.messages.GoogleCloudMlV1RequestLoggingConfig(
samplingPercentage=sampling_percentage)
else:
version.requestLoggingConfig.samplingPercentage = sampling_percentage
if not update_mask:
raise NoFieldsSpecifiedError('No updates requested.')
return self.client.projects_models_versions.Patch(
self.messages.MlProjectsModelsVersionsPatchRequest(
name=version_ref.RelativeName(),
googleCloudMlV1Version=version,
updateMask=','.join(sorted(update_mask))))
def Delete(self, version_ref):
"""Deletes a version from a model."""
return self.client.projects_models_versions.Delete(
self.messages.MlProjectsModelsVersionsDeleteRequest(
name=version_ref.RelativeName()))
def Get(self, version_ref):
"""Gets details about an existing model version."""
return self.client.projects_models_versions.Get(
self.messages.MlProjectsModelsVersionsGetRequest(
name=version_ref.RelativeName()))
def List(self, model_ref):
"""Lists the versions for a model."""
list_request = self.messages.MlProjectsModelsVersionsListRequest(
parent=model_ref.RelativeName())
return list_pager.YieldFromList(
self.client.projects_models_versions, list_request,
field='versions', batch_size_attribute='pageSize')
def SetDefault(self, version_ref):
"""Sets a model's default version."""
return self.client.projects_models_versions.SetDefault(
self._MakeSetDefaultRequest(name=version_ref.RelativeName()))
def ReadConfig(self, path, allowed_fields):
"""Read a config file and return Version object with the values.
The object is based on a YAML configuration file. The file may only
have the fields given in `allowed_fields`.
Args:
path: str, the path to the YAML file.
allowed_fields: Collection, the fields allowed in the YAML.
Returns:
A Version object (for the corresponding API version).
Raises:
InvalidVersionConfigFile: If the file contains unexpected fields.
"""
try:
data = yaml.load_path(path)
except (yaml.Error) as err:
raise InvalidVersionConfigFile(
'Could not read Version configuration file [{path}]:\n\n'
'{err}'.format(path=path, err=six.text_type(err.inner_error)))
if data:
version = encoding.DictToMessage(data, self.version_class)
specified_fields = set([f.name for f in version.all_fields() if
getattr(version, f.name)])
invalid_fields = (specified_fields - allowed_fields |
set(version.all_unrecognized_fields()))
if invalid_fields:
raise InvalidVersionConfigFile(
'Invalid {noun} [{fields}] in configuration file [{path}]. '
'Allowed fields: [{allowed}].'.format(
noun=text.Pluralize(len(invalid_fields), 'field'),
fields=', '.join(sorted(invalid_fields)),
path=path,
allowed=', '.join(sorted(allowed_fields))))
return version
def _ConfigureContainer(self, version, **kwargs):
"""Adds `container` and `routes` fields to version."""
if not any(kwargs.values()):
# Nothing related to containers was specified!
return
if not kwargs['image']: # Implied from above: some other parameter is set!
set_flags = ', '.join(
['--{}'.format(k) for k, v in sorted(kwargs.items()) if v])
raise ValueError(
'--image was not provided, but other container related flags were '
'specified. Please specify --image or remove the following flags: '
'{}'.format(set_flags))
version.container = self.messages.GoogleCloudMlV1ContainerSpec(
image=kwargs['image'])
if kwargs['command']:
version.container.command = kwargs['command']
if kwargs['args']:
version.container.args = kwargs['args']
if kwargs['env_vars']:
version.container.env = [
self.messages.GoogleCloudMlV1EnvVar(name=name, value=value)
for name, value in kwargs['env_vars'].items()]
if kwargs['ports']:
version.container.ports = [
self.messages.GoogleCloudMlV1ContainerPort(containerPort=p)
for p in kwargs['ports']
]
if kwargs['predict_route'] or kwargs['health_route']:
version.routes = self.messages.GoogleCloudMlV1RouteMap(
predict=kwargs['predict_route'],
health=kwargs['health_route']
)
def _ConfigureAutoScaling(self, version, **kwargs):
"""Adds `auto_scaling` fields to version."""
if not any(kwargs.values()):
# Nothing related to containers was specified!
return
version.autoScaling = self.messages.GoogleCloudMlV1AutoScaling()
if kwargs['min_nodes']:
version.autoScaling.minNodes = kwargs['min_nodes']
if kwargs['max_nodes']:
version.autoScaling.maxNodes = kwargs['max_nodes']
if kwargs['metrics']:
version.autoScaling.metrics = []
if 'cpu-usage' in kwargs['metrics']:
t = int(kwargs['metrics']['cpu-usage'])
version.autoScaling.metrics.append(
self.messages.GoogleCloudMlV1MetricSpec(
name=self.messages.GoogleCloudMlV1MetricSpec.NameValueValuesEnum
.CPU_USAGE,
target=t))
if 'gpu-duty-cycle' in kwargs['metrics']:
t = int(kwargs['metrics']['gpu-duty-cycle'])
version.autoScaling.metrics.append(
self.messages.GoogleCloudMlV1MetricSpec(
name=self.messages.GoogleCloudMlV1MetricSpec.NameValueValuesEnum
.GPU_DUTY_CYCLE,
target=t))
def BuildVersion(self,
name,
path=None,
deployment_uri=None,
runtime_version=None,
labels=None,
machine_type=None,
description=None,
framework=None,
python_version=None,
prediction_class=None,
package_uris=None,
accelerator_config=None,
service_account=None,
explanation_method=None,
num_integral_steps=None,
num_paths=None,
image=None,
command=None,
container_args=None,
env_vars=None,
ports=None,
predict_route=None,
health_route=None,
min_nodes=None,
max_nodes=None,
metrics=None,
containers_hidden=True,
autoscaling_hidden=True):
"""Create a Version object.
The object is based on an optional YAML configuration file and the
parameters to this method; any provided method parameters override any
provided in-file configuration.
The file may only have the fields given in
VersionsClientBase._ALLOWED_YAML_FIELDS specified; the only parameters
allowed are those that can be specified on the command line.
Args:
name: str, the name of the version object to create.
path: str, the path to the YAML file.
deployment_uri: str, the deploymentUri to set for the Version
runtime_version: str, the runtimeVersion to set for the Version
labels: Version.LabelsValue, the labels to set for the version
machine_type: str, the machine type to serve the model version on.
description: str, the version description.
framework: FrameworkValueValuesEnum, the ML framework used to train this
version of the model.
python_version: str, The version of Python used to train the model.
prediction_class: str, the FQN of a Python class implementing the Model
interface for custom prediction.
package_uris: list of str, Cloud Storage URIs containing user-supplied
Python code to use.
accelerator_config: an accelerator config message object.
service_account: Specifies the service account for resource access
control.
explanation_method: Enables explanations and selects the explanation
method. Valid options are 'integrated-gradients' and 'sampled-shapley'.
num_integral_steps: Number of integral steps for Integrated Gradients and
XRAI.
num_paths: Number of paths for Sampled Shapley.
image: The container image to deploy.
command: Entrypoint for the container image.
container_args: The command-line args to pass the container.
env_vars: The environment variables to set on the container.
ports: The ports to which traffic will be sent in the container.
predict_route: The HTTP path within the container that predict requests
are sent to.
health_route: The HTTP path within the container that health checks are
sent to.
min_nodes: The minimum number of nodes to scale this model under load.
max_nodes: The maximum number of nodes to scale this model under load.
metrics: List of key-value pairs to set as metrics' target for
autoscaling.
containers_hidden: Whether or not container-related fields are hidden on
this track.
autoscaling_hidden: Whether or not autoscaling fields are hidden on this
track.
Returns:
A Version object (for the corresponding API version).
Raises:
InvalidVersionConfigFile: If the file contains unexpected fields.
"""
if path:
allowed_fields = self._ALLOWED_YAML_FIELDS
if not containers_hidden:
allowed_fields |= self._CONTAINER_FIELDS
version = self.ReadConfig(path, allowed_fields)
else:
version = self.version_class()
additional_fields = {
'name': name,
'deploymentUri': deployment_uri,
'runtimeVersion': runtime_version,
'labels': labels,
'machineType': machine_type,
'description': description,
'framework': framework,
'pythonVersion': python_version,
'predictionClass': prediction_class,
'packageUris': package_uris,
'acceleratorConfig': accelerator_config,
'serviceAccount': service_account
}
explanation_config = None
if explanation_method == 'integrated-gradients':
explanation_config = self.messages.GoogleCloudMlV1ExplanationConfig()
ig_config = self.messages.GoogleCloudMlV1IntegratedGradientsAttribution()
ig_config.numIntegralSteps = num_integral_steps
explanation_config.integratedGradientsAttribution = ig_config
elif explanation_method == 'sampled-shapley':
explanation_config = self.messages.GoogleCloudMlV1ExplanationConfig()
shap_config = self.messages.GoogleCloudMlV1SampledShapleyAttribution()
shap_config.numPaths = num_paths
explanation_config.sampledShapleyAttribution = shap_config
elif explanation_method == 'xrai':
explanation_config = self.messages.GoogleCloudMlV1ExplanationConfig()
xrai_config = self.messages.GoogleCloudMlV1XraiAttribution()
xrai_config.numIntegralSteps = num_integral_steps
explanation_config.xraiAttribution = xrai_config
if explanation_config is not None:
additional_fields['explanationConfig'] = explanation_config
if not containers_hidden:
self._ConfigureContainer(
version,
image=image,
command=command,
args=container_args,
env_vars=env_vars,
ports=ports,
predict_route=predict_route,
health_route=health_route)
if not autoscaling_hidden:
self._ConfigureAutoScaling(
version, min_nodes=min_nodes, max_nodes=max_nodes, metrics=metrics)
for field_name, value in additional_fields.items():
if value is not None:
setattr(version, field_name, value)
return version