330 lines
12 KiB
Python
330 lines
12 KiB
Python
# -*- coding: utf-8 -*- #
|
|
# Copyright 2016 Google LLC. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Utilities for dealing with ML jobs API."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import unicode_literals
|
|
|
|
from apitools.base.py import encoding
|
|
from apitools.base.py import list_pager
|
|
from googlecloudsdk.api_lib.util import apis
|
|
from googlecloudsdk.command_lib.util.apis import arg_utils
|
|
from googlecloudsdk.core import exceptions
|
|
from googlecloudsdk.core import properties
|
|
from googlecloudsdk.core import resources
|
|
from googlecloudsdk.core import yaml
|
|
|
|
|
|
class NoFieldsSpecifiedError(exceptions.Error):
  """Raised when a Patch call would not change any field of the job."""
|
|
|
|
|
|
class NoPackagesSpecifiedError(exceptions.Error):
  """Raised when a training job lacks both package URIs and a custom image."""
|
|
|
|
|
|
def GetMessagesModule(version='v1'):
  """Returns the generated messages module of the `ml` API.

  Args:
    version: str, the API version whose messages module is requested.

  Returns:
    Whatever `apis.GetMessagesModule('ml', version)` returns.
  """
  api_name = 'ml'
  return apis.GetMessagesModule(api_name, version)
|
|
|
|
|
|
def GetClientInstance(version='v1', no_http=False):
  """Returns an API client for the `ml` API.

  Args:
    version: str, the API version the client should target.
    no_http: bool, forwarded unchanged to `apis.GetClientInstance`.

  Returns:
    Whatever `apis.GetClientInstance('ml', version, no_http=...)` returns.
  """
  api_name = 'ml'
  return apis.GetClientInstance(api_name, version, no_http=no_http)
|
|
|
|
|
|
class JobsClient(object):
  """Client for jobs service in the Cloud ML Engine API.

  Builds and sends requests against the `projects_jobs` service of the
  wrapped API client (list / create / cancel / get / patch) and assembles
  training and batch-prediction Job messages from flag values.
  """

  def __init__(self, client=None, messages=None,
               short_message_prefix='GoogleCloudMlV1', client_version='v1'):
    """Initializes the jobs client.

    Args:
      client: API client to use; when falsy, one is created with
        GetClientInstance(client_version).
      messages: messages module to use; when falsy, the client's
        MESSAGES_MODULE is used.
      short_message_prefix: str, prefix GetShortMessage prepends to names.
      client_version: str, API version used only when `client` is falsy.
    """
    if not client:
      client = GetClientInstance(client_version)
    self.client = client
    if not messages:
      messages = client.MESSAGES_MODULE
    self.messages = messages
    self._short_message_prefix = short_message_prefix

  def GetShortMessage(self, short_message_name):
    """Looks up a message class by its unprefixed name.

    Args:
      short_message_name: str, e.g. 'AcceleratorConfig'.

    Returns:
      The attribute `<short_message_prefix><short_message_name>` of the
      messages module, or None when no such attribute exists.
    """
    full_name = '{}{}'.format(self._short_message_prefix, short_message_name)
    return getattr(self.messages, full_name, None)

  @property
  def state_enum(self):
    """The StateValueValuesEnum nested in the GoogleCloudMlV1Job message."""
    job_message = self.messages.GoogleCloudMlV1Job
    return job_message.StateValueValuesEnum

  def List(self, project_ref):
    """Lists the jobs under a project.

    Args:
      project_ref: resource reference of the parent project.

    Returns:
      The pager produced by list_pager.YieldFromList over the `jobs` field,
      using `pageSize` as the batch-size attribute.
    """
    list_request = self.messages.MlProjectsJobsListRequest(
        parent=project_ref.RelativeName())
    jobs_service = self.client.projects_jobs
    return list_pager.YieldFromList(
        jobs_service,
        list_request,
        field='jobs',
        batch_size_attribute='pageSize')

  @property
  def job_class(self):
    """The GoogleCloudMlV1Job message class."""
    return self.messages.GoogleCloudMlV1Job

  @property
  def training_input_class(self):
    """The GoogleCloudMlV1TrainingInput message class."""
    return self.messages.GoogleCloudMlV1TrainingInput

  @property
  def prediction_input_class(self):
    """The GoogleCloudMlV1PredictionInput message class."""
    return self.messages.GoogleCloudMlV1PredictionInput

  def _MakeCreateRequest(self, parent=None, job=None):
    """Wraps `job` in an MlProjectsJobsCreateRequest addressed to `parent`."""
    request_class = self.messages.MlProjectsJobsCreateRequest
    return request_class(googleCloudMlV1Job=job, parent=parent)

  def Create(self, project_ref, job):
    """Submits `job` for creation under the project `project_ref`."""
    create_request = self._MakeCreateRequest(
        parent=project_ref.RelativeName(), job=job)
    return self.client.projects_jobs.Create(create_request)

  def Cancel(self, job_ref):
    """Cancels given job."""
    job_name = job_ref.RelativeName()
    cancel_request = self.messages.MlProjectsJobsCancelRequest(name=job_name)
    return self.client.projects_jobs.Cancel(cancel_request)

  def Get(self, job_ref):
    """Fetches the job identified by `job_ref`."""
    job_name = job_ref.RelativeName()
    get_request = self.messages.MlProjectsJobsGetRequest(name=job_name)
    return self.client.projects_jobs.Get(get_request)

  def Patch(self, job_ref, labels_update):
    """Update a job.

    Only labels are patchable here: they are copied from `labels_update`
    when its `needs_update` flag is set.

    Args:
      job_ref: resource reference of the job to update.
      labels_update: object exposing `needs_update` and `labels`.

    Returns:
      The result of projects_jobs.Patch.

    Raises:
      NoFieldsSpecifiedError: if `labels_update` requests no change.
    """
    job = self.job_class()
    if not labels_update.needs_update:
      raise NoFieldsSpecifiedError('No updates requested.')
    job.labels = labels_update.labels
    # Kept as a list so further patchable fields can be appended.
    mask_fields = ['labels']

    patch_request = self.messages.MlProjectsJobsPatchRequest(
        name=job_ref.RelativeName(),
        googleCloudMlV1Job=job,
        updateMask=','.join(mask_fields))
    return self.client.projects_jobs.Patch(patch_request)

  def BuildTrainingJob(self,
                       path=None,
                       module_name=None,
                       job_name=None,
                       trainer_uri=None,
                       region=None,
                       job_dir=None,
                       scale_tier=None,
                       user_args=None,
                       runtime_version=None,
                       python_version=None,
                       network=None,
                       service_account=None,
                       labels=None,
                       kms_key=None,
                       custom_train_server_config=None,
                       enable_web_access=None):
    """Builds a Cloud ML Engine Job from a config file and/or flag values.

    The YAML file (if any) is loaded first; every non-None flag value below
    is then written on top of it, so flags override the file.

    Args:
      path: path to a yaml configuration file
      module_name: sets trainingInput.pythonModule
      job_name: sets jobId (only when truthy)
      trainer_uri: list of package URIs; sets trainingInput.packageUris
      region: sets trainingInput.region
      job_dir: sets trainingInput.jobDir
      scale_tier: ScaleTierValueValuesEnum; sets trainingInput.scaleTier
      user_args: [str]; sets trainingInput.args
      runtime_version: sets trainingInput.runtimeVersion
      python_version: sets trainingInput.pythonVersion
      network: sets trainingInput.network
      service_account: str email; sets trainingInput.serviceAccount
      labels: Job.LabelsValue; replaces the job's labels when not None
      kms_key: sets trainingInput.encryptionConfig.kmsKeyName when truthy
      custom_train_server_config: object whose GetFieldMap() returns a dict.
        Entries ending in 'Config' (but not 'TfConfig') are nested dicts
        whose imageUri, acceleratorConfig.type/count and (workerConfig only)
        tpuTfVersion are copied when truthy. Other non-None entries are set
        verbatim on trainingInput.
      enable_web_access: sets trainingInput.enableWebAccess

    Raises:
      NoPackagesSpecifiedError: if the result has neither package URIs nor a
        masterConfig.imageUri.

    Returns:
      A constructed Job object.
    """
    job = self.job_class()

    # TODO(b/123467089): Remove yaml file loading here, only parse data objects
    if path:
      data = yaml.load_path(path)
      if data:
        job = encoding.DictToMessage(data, self.job_class)

    if job_name:
      job.jobId = job_name

    if labels is not None:
      job.labels = labels

    if not job.trainingInput:
      job.trainingInput = self.training_input_class()

    flag_overrides = (
        ('pythonModule', module_name),
        ('args', user_args),
        ('packageUris', trainer_uri),
        ('region', region),
        ('jobDir', job_dir),
        ('scaleTier', scale_tier),
        ('runtimeVersion', runtime_version),
        ('pythonVersion', python_version),
        ('network', network),
        ('serviceAccount', service_account),
        ('enableWebAccess', enable_web_access),
    )
    training_input = job.trainingInput
    for attr_name, attr_value in flag_overrides:
      if attr_value is not None:
        setattr(training_input, attr_name, attr_value)

    if kms_key:
      arg_utils.SetFieldInMessage(
          job, 'trainingInput.encryptionConfig.kmsKeyName', kms_key)

    if custom_train_server_config:
      server_fields = custom_train_server_config.GetFieldMap()
      for field_name, value in server_fields.items():
        if value is None:
          continue
        # Replica configs (names ending in 'Config') are nested dicts applied
        # leaf by leaf. Names ending in 'TfConfig' are excluded and, like
        # every other entry, copied onto trainingInput verbatim.
        is_replica_config = (field_name.endswith('Config') and
                             not field_name.endswith('TfConfig'))
        if not is_replica_config:
          setattr(job.trainingInput, field_name, value)
          continue

        base_path = 'trainingInput.' + field_name
        if value['imageUri']:
          arg_utils.SetFieldInMessage(
              job, base_path + '.imageUri', value['imageUri'])
        for accel_key in ('type', 'count'):
          accel_value = value['acceleratorConfig'][accel_key]
          if accel_value:
            arg_utils.SetFieldInMessage(
                job, base_path + '.acceleratorConfig.' + accel_key,
                accel_value)
        # tpuTfVersion is only consulted on the worker config.
        if field_name == 'workerConfig' and value['tpuTfVersion']:
          arg_utils.SetFieldInMessage(
              job, base_path + '.tpuTfVersion', value['tpuTfVersion'])

    if not (self.HasPackageURIs(job) or self.IsCustomContainerTraining(job)):
      raise NoPackagesSpecifiedError('Non-custom jobs must have packages.')

    return job

  def HasPackageURIs(self, job):
    """Returns True when the job's trainingInput lists any package URI."""
    package_uris = job.trainingInput.packageUris
    return bool(package_uris)

  def IsCustomContainerTraining(self, job):
    """Returns True when trainingInput.masterConfig carries an imageUri."""
    master_config = job.trainingInput.masterConfig
    if not master_config:
      return False
    return bool(master_config.imageUri)

  def BuildBatchPredictionJob(self,
                              job_name=None,
                              model_dir=None,
                              model_name=None,
                              version_name=None,
                              input_paths=None,
                              data_format=None,
                              output_path=None,
                              region=None,
                              runtime_version=None,
                              max_worker_count=None,
                              batch_size=None,
                              signature_name=None,
                              labels=None,
                              accelerator_count=None,
                              accelerator_type=None):
    """Builds a Cloud ML Engine Job for batch prediction from flag values.

    The model source is chosen in this order: `model_dir` (set as uri), then
    `version_name` (resolved under `model_name` in the current project), then
    `model_name` alone.

    Args:
      job_name: value to set for the jobId field
      model_dir: str, Google Cloud Storage location of the model files
      model_name: str, model to resolve in the current project
      version_name: str, version of `model_name` to resolve
      input_paths: list of input files
      data_format: value passed to PredictionInput.DataFormatValueValuesEnum
      output_path: single value for the output location
      region: compute region in which to run the job
      runtime_version: the runtime version in which to run the job
      max_worker_count: int, the maximum number of workers to use
      batch_size: int, the number of records per batch sent to Tensorflow
      signature_name: str, name of input/output signature in the TF meta graph
      labels: Job.LabelsValue, the Cloud labels for the job
      accelerator_count: int, number of accelerators to attach; only used
        when accelerator_type is set
      accelerator_type: AcceleratorsValueListEntryValuesEnum, type of
        accelerator to add; when falsy no accelerator config is attached

    Returns:
      A constructed Job object.
    """
    # Presumably raises when core/project is unset (GetOrFail) -- confirm.
    project_id = properties.VALUES.core.project.GetOrFail()

    accelerator_config = None
    if accelerator_type:
      accelerator_message = self.GetShortMessage('AcceleratorConfig')
      accelerator_config = accelerator_message(
          count=accelerator_count, type=accelerator_type)

    input_fields = {
        'inputPaths': input_paths,
        'outputPath': output_path,
        'region': region,
        'runtimeVersion': runtime_version,
        'maxWorkerCount': max_worker_count,
        'batchSize': batch_size,
        'accelerator': accelerator_config,
    }
    prediction_input = self.prediction_input_class(**input_fields)

    format_enum = prediction_input.DataFormatValueValuesEnum
    prediction_input.dataFormat = format_enum(data_format)

    if model_dir:
      prediction_input.uri = model_dir
    elif version_name:
      version_ref = resources.REGISTRY.Parse(
          version_name,
          collection='ml.projects.models.versions',
          params={'modelsId': model_name, 'projectsId': project_id})
      prediction_input.versionName = version_ref.RelativeName()
    else:
      model_ref = resources.REGISTRY.Parse(
          model_name,
          collection='ml.projects.models',
          params={'projectsId': project_id})
      prediction_input.modelName = model_ref.RelativeName()

    if signature_name:
      prediction_input.signatureName = signature_name

    return self.job_class(
        predictionInput=prediction_input,
        jobId=job_name,
        labels=labels)
|