feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
# Copyright 2016 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: This file still exists explicitly in google3 because usages of
# "from google import cloud" need the file to make strict deps work when the
# targets depend on "//third_party/py/google/cloud:core" (:core globs "*.py").

View File

@@ -0,0 +1,15 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Google Cloud Machine Learning Prediction SDK.
"""

View File

@@ -0,0 +1,19 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes for dealing with I/O from ML pipelines.
"""
from google.cloud.ml.io.coders import CsvCoder
from google.cloud.ml.io.coders import ExampleProtoCoder
from google.cloud.ml.io.coders import JsonCoder

View File

@@ -0,0 +1,504 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes for dealing with I/O from ML pipelines.
"""
import csv
import datetime
import json
import logging
import apache_beam as beam
from six.moves import cStringIO
import yaml
from google.cloud.ml.util import _decoders
from google.cloud.ml.util import _file
# TODO(user): Use a ProtoCoder once b/29055158 is resolved.
class ExampleProtoCoder(beam.coders.Coder):
  """A coder to encode and decode TensorFlow Example objects."""

  def __init__(self):
    # Imported lazily so the module can load without TensorFlow installed.
    import tensorflow as tf  # pylint: disable=g-import-not-at-top
    self._tf_train = tf.train

  def encode(self, example_proto):
    """Serializes a TensorFlow Example object to a string.

    Args:
      example_proto: A Tensorflow Example object

    Returns:
      String.
    """
    return example_proto.SerializeToString()

  def decode(self, serialized_str):
    """Parses a serialized string back into a TensorFlow Example object.

    Args:
      serialized_str: string

    Returns:
      Tensorflow Example object.
    """
    parsed_example = self._tf_train.Example()
    parsed_example.ParseFromString(serialized_str)
    return parsed_example
class JsonCoder(beam.coders.Coder):
  """A coder to encode and decode JSON formatted data."""

  def __init__(self, indent=None):
    self._indent = indent

  def encode(self, obj):
    """Encodes a python object into a JSON string.

    Args:
      obj: python object.

    Returns:
      JSON string.
    """
    # Explicit separators avoid the trailing whitespace json.dumps would
    # otherwise emit after each comma/colon.
    compact_separators = (',', ': ')
    return json.dumps(obj, indent=self._indent, separators=compact_separators)

  def decode(self, json_string):
    """Decodes a JSON string to a python object.

    Args:
      json_string: A JSON string.

    Returns:
      A python object.
    """
    return json.loads(json_string)
class CsvCoder(beam.coders.Coder):
  """A coder to encode and decode CSV formatted data."""

  class _WriterWrapper(object):
    """A picklable wrapper around csv.writer / csv.DictWriter."""

    def __init__(self, column_names, delimiter, decode_to_dict):
      # Constructor args are kept so pickling can rebuild the wrapper.
      self._state = (column_names, delimiter, decode_to_dict)
      self._buffer = cStringIO()
      if decode_to_dict:
        self._writer = csv.DictWriter(
            self._buffer,
            column_names,
            lineterminator='',
            delimiter=delimiter)
      else:
        self._writer = csv.writer(
            self._buffer,
            lineterminator='',
            delimiter=delimiter)

    def encode_record(self, record):
      self._writer.writerow(record)
      encoded = self._buffer.getvalue()
      # Rewind and clear the buffer so it can be reused for the next record.
      self._buffer.seek(0)
      self._buffer.truncate(0)
      return encoded

    def __getstate__(self):
      return self._state

    def __setstate__(self, state):
      self.__init__(*state)

  def __init__(self, column_names, numeric_column_names, delimiter=',',
               decode_to_dict=True, fail_on_error=True,
               skip_initial_space=False):
    """Initializes CsvCoder.

    Args:
      column_names: Tuple of strings. Order must match the order in the file.
      numeric_column_names: Tuple of strings. Contains column names that are
        numeric. Every name in numeric_column_names must also be in
        column_names.
      delimiter: A one-character string used to separate fields.
      decode_to_dict: Boolean indicating whether the decoder should generate a
        dictionary instead of a raw sequence. True by default.
      fail_on_error: Whether to fail if a corrupt row is found. Default is
        True.
      skip_initial_space: When True, whitespace immediately following the
        delimiter is ignored when reading.
    """
    self._decoder = _decoders.CsvDecoder(
        column_names, numeric_column_names, delimiter, decode_to_dict,
        fail_on_error, skip_initial_space)
    self._encoder = self._WriterWrapper(
        column_names=column_names,
        delimiter=delimiter,
        decode_to_dict=decode_to_dict)

  def decode(self, csv_line):
    """Decode one csv line into a python dict or sequence.

    Args:
      csv_line: String. One csv line from the file.

    Returns:
      Python dict where the keys are the column names from the file. The dict
      values are strings or numbers depending if a column name was listed in
      numeric_column_names. Missing string columns have the value '', while
      missing numeric columns have the value None. If there is an error in
      parsing csv_line, a python dict is returned where every value is '' or
      None.

    Raises:
      Exception: If the number of columns does not match.
    """
    return self._decoder.decode(csv_line)

  def encode(self, python_data):
    """Encode a python dict or sequence into a csv-formatted string.

    Args:
      python_data: A python collection, depending on the value of
        decode_to_dict it will be a python dictionary where the keys are the
        column names or a sequence.

    Returns:
      A csv-formatted string. The order of the columns is given by
      column_names.
    """
    return self._encoder.encode_record(python_data)
class YamlCoder(beam.coders.Coder):
  """A coder to encode and decode YAML formatted data."""

  def __init__(self):
    """Picks the libyaml-backed yaml classes when they are available.

    If libyaml is not available we fall back to the native yaml library; use
    with caution, since it is far less efficient, uses excessive memory, and
    leaks memory.
    """
    # TODO(user): Always use libyaml once possible.
    if yaml.__with_libyaml__:
      self._safe_dumper = yaml.CSafeDumper
      self._safe_loader = yaml.CSafeLoader
    else:
      logging.warning(
          "Can't find libyaml so it is not used for YamlCoder, the "
          'implementation used is far slower and has a memory leak.')
      self._safe_dumper = yaml.SafeDumper
      self._safe_loader = yaml.SafeLoader

  def encode(self, obj):
    """Encodes a python object into a YAML string.

    Args:
      obj: python object.

    Returns:
      YAML string.
    """
    return yaml.dump(
        obj,
        default_flow_style=False,
        encoding='utf-8',
        Dumper=self._safe_dumper)

  def decode(self, yaml_string):
    """Decodes a YAML string to a python object.

    Args:
      yaml_string: A YAML string.

    Returns:
      A python object.
    """
    return yaml.load(yaml_string, Loader=self._safe_loader)
class MetadataCoder(beam.coders.Coder):
  """A coder to encode and decode CloudML metadata."""

  def encode(self, obj):
    """Encodes a python object into an indented JSON string.

    Args:
      obj: python object.

    Returns:
      JSON string.
    """
    return JsonCoder(indent=1).encode(obj)

  def decode(self, metadata_string):
    """Decodes a metadata string to a python object.

    Args:
      metadata_string: A metadata string, either in json or yaml format.

    Returns:
      A python object.
    """
    return self._decode_internal(metadata_string)

  @classmethod
  def load_from(cls, path):
    """Reads and decodes a metadata file.

    Assumes the file is in json format by default and falls back to yaml
    format if that fails.

    Args:
      path: A metadata file path string.

    Returns:
      A decoded metadata object.
    """
    return cls._decode_internal(_file.load_file(path))

  @staticmethod
  def _decode_internal(metadata_string):
    # Try json first; a ValueError means it was not json, so fall back to
    # yaml.
    try:
      return JsonCoder().decode(metadata_string)
    except ValueError:
      return YamlCoder().decode(metadata_string)
class TrainingJobRequestCoder(beam.coders.Coder):
  """Custom coder for a TrainingJobRequest object."""

  def encode(self, training_job_request):
    """Encode a TrainingJobRequest to a JSON string.

    Args:
      training_job_request: A TrainingJobRequest object.

    Returns:
      A JSON string
    """
    payload = {}
    payload.update(training_job_request.__dict__)
    # timedelta values are not json encodable; store them as total seconds.
    for field in ['timeout', 'polling_interval']:
      if payload[field]:
        payload[field] = payload[field].total_seconds()
    return json.dumps(payload)

  def decode(self, training_job_request_string):
    """Decode a JSON string representing a TrainingJobRequest.

    Args:
      training_job_request_string: A string representing a TrainingJobRequest.

    Returns:
      TrainingJobRequest object.
    """
    request = TrainingJobRequest()
    payload = json.loads(training_job_request_string)
    # Restore timedelta values from their total-seconds encoding.
    for field in ['timeout', 'polling_interval']:
      if payload[field]:
        payload[field] = datetime.timedelta(seconds=payload[field])
    request.__dict__.update(payload)
    return request
class TrainingJobResultCoder(beam.coders.Coder):
  """Custom coder for TrainingJobResult."""

  def encode(self, training_job_result):
    """Encode a TrainingJobResult object into a JSON string.

    Args:
      training_job_result: A TrainingJobResult object.

    Returns:
      A JSON string
    """
    payload = {}
    payload.update(training_job_result.__dict__)
    # The nested request object needs its own coder.
    if payload['training_request'] is not None:
      payload['training_request'] = TrainingJobRequestCoder().encode(
          payload['training_request'])
    return json.dumps(payload)

  def decode(self, training_job_result_string):
    """Decode a string to a TrainingJobResult object.

    Args:
      training_job_result_string: A string representing a TrainingJobResult.

    Returns:
      A TrainingJobResult object.
    """
    result = TrainingJobResult()
    payload = json.loads(training_job_result_string)
    # The nested request object needs its own coder.
    if payload['training_request'] is not None:
      payload['training_request'] = TrainingJobRequestCoder().decode(
          payload['training_request'])
    result.__dict__.update(payload)
    return result
class TrainingJobRequest(object):
  """This class contains the parameters for running a training job."""

  def __init__(self,
               parent=None,
               job_name=None,
               job_args=None,
               package_uris=None,
               python_module=None,
               timeout=None,
               polling_interval=datetime.timedelta(seconds=30),
               scale_tier=None,
               hyperparameters=None,
               region=None,
               master_type=None,
               worker_type=None,
               ps_type=None,
               worker_count=None,
               ps_count=None,
               endpoint=None,
               runtime_version=None):
    """Construct an instance of TrainingSpec.

    Args:
      parent: The project name. This is named parent because the parent object
        of jobs is the project.
      job_name: A job name. This must be unique within the project.
      job_args: Additional arguments to pass to the job.
      package_uris: A list of URIs to tarballs with the training program.
      python_module: The module name of the python file within the tarball.
      timeout: A datetime.timedelta expressing the amount of time to wait
        before giving up. The timeout applies to a single invocation of the
        process method in TrainModelDo. A DoFn can be retried several times
        before a pipeline fails.
      polling_interval: A datetime.timedelta to represent the amount of time
        to wait between requests polling for the files.
      scale_tier: Google Cloud ML tier to run in.
      hyperparameters: (Optional) Hyperparameter config to use for the job.
      region: (Optional) Google Cloud region in which to run.
      master_type: Master type to use with a CUSTOM scale tier.
      worker_type: Worker type to use with a CUSTOM scale tier.
      ps_type: Parameter Server type to use with a CUSTOM scale tier.
      worker_count: Worker count to use with a CUSTOM scale tier.
      ps_count: Parameter Server count to use with a CUSTOM scale tier.
      endpoint: (Optional) The endpoint for the Cloud ML API.
      runtime_version: (Optional) the Google Cloud ML runtime version to use.
    """
    self.parent = parent
    self.job_name = job_name
    self.job_args = job_args
    self.python_module = python_module
    self.package_uris = package_uris
    self.scale_tier = scale_tier
    self.hyperparameters = hyperparameters
    self.region = region
    self.master_type = master_type
    self.worker_type = worker_type
    self.ps_type = ps_type
    self.worker_count = worker_count
    self.ps_count = ps_count
    self.timeout = timeout
    self.polling_interval = polling_interval
    self.endpoint = endpoint
    self.runtime_version = runtime_version

  @property
  def project(self):
    # Alias: the Cloud ML API calls the owning project the job's "parent".
    return self.parent

  def copy(self):
    """Return a shallow copy of the object."""
    r = TrainingJobRequest()
    r.__dict__.update(self.__dict__)
    return r

  def __eq__(self, o):
    for f in ['parent', 'job_name', 'job_args', 'package_uris',
              'python_module', 'timeout', 'polling_interval', 'endpoint',
              'hyperparameters', 'scale_tier', 'worker_type', 'ps_type',
              'master_type', 'region', 'ps_count', 'worker_count',
              'runtime_version']:
      if getattr(self, f) != getattr(o, f):
        return False
    return True

  def __ne__(self, o):
    return not self == o

  def __repr__(self):
    # Bug fix: the original used dict.iteritems(), which only exists on
    # Python 2 and raises AttributeError on Python 3; items() works on both.
    fields = ['{0}={1}'.format(k, v) for k, v in self.__dict__.items()]
    return 'TrainingJobRequest({0})'.format(', '.join(fields))
# Register the coder so Beam pipelines automatically (de)serialize
# TrainingJobRequest elements with TrainingJobRequestCoder.
beam.coders.registry.register_coder(TrainingJobRequest, TrainingJobRequestCoder)
class TrainingJobResult(object):
  """Result of training a model."""

  def __init__(self):
    # A copy of the training request that created the job.
    self.training_request = None
    # An instance of TrainingJobMetadata as returned by the API.
    self.training_job_metadata = None
    # At most one of error and training_job_result will be specified.
    # These fields will only be supplied if the job completed.
    # training_job_result will be provided if the job completed successfully
    # and error will be supplied otherwise.
    self.error = None
    self.training_job_result = None

  def __eq__(self, o):
    for f in ['training_request', 'training_job_metadata', 'error',
              'training_job_result']:
      if getattr(self, f) != getattr(o, f):
        return False
    return True

  def __ne__(self, o):
    return not self == o

  def __repr__(self):
    # Bug fix: the original used dict.iteritems(), which only exists on
    # Python 2 and raises AttributeError on Python 3; items() works on both.
    fields = ['{0}={1}'.format(k, v) for k, v in self.__dict__.items()]
    return 'TrainingJobResult({0})'.format(', '.join(fields))
# Register the coder so Beam pipelines automatically (de)serialize
# TrainingJobResult elements with TrainingJobResultCoder.
beam.coders.registry.register_coder(TrainingJobResult, TrainingJobResultCoder)

View File

@@ -0,0 +1,47 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=g-import-not-at-top
"""Classes and methods for predictions on a trained machine learning model.
"""
from ._interfaces import Model
from ._interfaces import PredictionClient
from .custom_code_utils import create_user_model
from .custom_code_utils import load_custom_class
from .prediction_lib import create_client
from .prediction_lib import create_model
from .prediction_lib import local_predict
from .prediction_utils import ALIAS_TIME
from .prediction_utils import BaseModel
from .prediction_utils import COLUMNARIZE_TIME
from .prediction_utils import copy_model_to_local
from .prediction_utils import decode_base64
from .prediction_utils import detect_framework
from .prediction_utils import does_signature_contain_str
from .prediction_utils import ENCODE_TIME
from .prediction_utils import ENGINE
from .prediction_utils import ENGINE_RUN_TIME
from .prediction_utils import FRAMEWORK
from .prediction_utils import LOCAL_MODEL_PATH
from .prediction_utils import PredictionError
from .prediction_utils import ROWIFY_TIME
from .prediction_utils import SCIKIT_LEARN_FRAMEWORK_NAME
from .prediction_utils import SESSION_RUN_ENGINE_NAME
from .prediction_utils import SESSION_RUN_TIME
from .prediction_utils import SIGNATURE_KEY
from .prediction_utils import SK_XGB_FRAMEWORK_NAME
from .prediction_utils import Stats
from .prediction_utils import TENSORFLOW_FRAMEWORK_NAME
from .prediction_utils import Timer
from .prediction_utils import UNALIAS_TIME
from .prediction_utils import XGBOOST_FRAMEWORK_NAME

View File

@@ -0,0 +1,153 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interfaces and other classes for providing custom code for prediction."""
class Model(object):
  """A Model performs predictions on a given list of instances.

  The input instances are the raw values sent by the user. It is the
  responsibility of a Model to translate these instances into
  actual predictions.

  The input instances and the output use python data types. The input
  instances have been decoded prior to being passed to the predict
  method. The output, which should use python data types is
  encoded after being returned from the predict method.
  """

  def predict(self, instances, **kwargs):
    """Returns predictions for the provided instances.

    Instances are the decoded values from the request. Clients need not worry
    about decoding json nor base64 decoding.

    Args:
      instances: A list of instances, as described in the API.
      **kwargs: Additional keyword arguments, will be passed into the client's
        predict method.

    Returns:
      A list of outputs containing the prediction results.

    Raises:
      PredictionError: If an error occurs during prediction.
    """
    # Subclasses must override this.
    raise NotImplementedError

  @classmethod
  def from_path(cls, model_path):
    """Creates a model using the given model path.

    Path is useful, e.g., to load files from the exported directory containing
    the model.

    Args:
      model_path: The local directory that contains the exported model file
        along with any additional files uploaded when creating the version
        resource.

    Returns:
      An instance implementing this Model class.
    """
    # Subclasses must override this alternate constructor.
    raise NotImplementedError
class PredictionClient(object):
  """A client for Prediction.

  No assumptions are made about whether the prediction happens in process,
  across processes, or even over the network.

  The inputs, unlike Model.predict, have already been "columnarized", i.e.,
  a dict mapping input names to values for a whole batch, much like
  Session.run's feed_dict parameter. The return value is the same format.
  """

  def __init__(self, *args, **kwargs):
    # Accepts (and ignores) any construction arguments so subclasses are free
    # to define their own signatures.
    pass

  def predict(self, inputs, **kwargs):
    """Produces predictions for the given inputs.

    Args:
      inputs: A dict mapping input names to values.
      **kwargs: Additional keyword arguments for prediction

    Returns:
      A dict mapping output names to output values, similar to the input
      dict.
    """
    # Subclasses must override this.
    raise NotImplementedError

  def explain(self, inputs, **kwargs):
    """Produces explanations for the given inputs.

    Args:
      inputs: A dict mapping input names to values.
      **kwargs: Additional keyword arguments for prediction

    Returns:
      A dict mapping output names to output values, similar to the input
      dict.
    """
    # Subclasses must override this.
    raise NotImplementedError
class Processor(object):
  """Interface for constructing instance processors."""

  @classmethod
  def from_model_path(cls, model_path):
    """Creates a processor using the given model path.

    Args:
      model_path: The path to the stored model.

    Returns:
      An instance implementing this Processor class.
    """
    # Subclasses must override this alternate constructor.
    raise NotImplementedError
class Preprocessor(object):
  """Interface for processing a list of instances before prediction."""

  def preprocess(self, instances, **kwargs):
    """The preprocessing function.

    Args:
      instances: A list of instances, as provided to the predict() method.
      **kwargs: Additional keyword arguments for preprocessing.

    Returns:
      The processed instance to use in the predict() method.
    """
    # Subclasses must override this.
    raise NotImplementedError
class Postprocessor(object):
  """Interface for processing a list of instances after prediction."""

  def postprocess(self, instances, **kwargs):
    """The postprocessing function.

    Args:
      instances: A list of instances, as provided to the predict() method.
      **kwargs: Additional keyword arguments for postprocessing.

    Returns:
      The processed instance to return as the final prediction output.
    """
    # Subclasses must override this.
    raise NotImplementedError

View File

@@ -0,0 +1,129 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for loading user provided prediction code.
"""
import inspect
import json
import os
import pydoc # used for importing python classes from their FQN
import sys
from ._interfaces import Model
from .prediction_utils import PredictionError
# Key within the version spec of create_version_request that names the user's
# fully qualified custom Model class.
_PREDICTION_CLASS_KEY = "prediction_class"
def create_user_model(model_path, unused_flags):
  """Loads in the user specified custom Model class.

  Args:
    model_path: The path to either session_bundle or SavedModel.
    unused_flags: Required since model creation for other frameworks needs the
      additional flags params. And model creation is called in a framework
      agnostic manner.

  Returns:
    An instance of a Model, or None if the user didn't specify the name of the
    custom python class to load in the create_version_request.

  Raises:
    PredictionError: for any of the following:
      (1) the user provided python model class cannot be found
      (2) if the loaded class does not implement the Model interface.
  """
  user_class = load_custom_class()
  if user_class is None:
    # No custom class was named in the request.
    return None
  _validate_prediction_class(user_class)
  return user_class.from_path(model_path)
def load_custom_class():
  """Loads in the user specified custom class.

  Returns:
    The class specified by the user in the `create_version_request`
    environment variable, or None if no such class was specified.

  Raises:
    PredictionError: if the user provided python class cannot be found.
  """
  create_version_json = os.environ.get("create_version_request")
  if not create_version_json:
    return None
  create_version_request = json.loads(create_version_json)
  if not create_version_request:
    return None
  version = create_version_request.get("version")
  if not version:
    return None
  class_name = version.get(_PREDICTION_CLASS_KEY)
  if not class_name:
    return None
  # pydoc.locate resolves a dotted name to the object it names, or None.
  custom_class = pydoc.locate(class_name)
  # TODO(user): right place to generate errors?
  if not custom_class:
    # Bug fix: "package_uris" may be absent from the version spec, in which
    # case .get() returns None and iterating it would raise TypeError while
    # building the error message; default to an empty list instead.
    package_uris = [str(s) for s in version.get("package_uris") or []]
    raise PredictionError(
        PredictionError.INVALID_USER_CODE,
        "%s cannot be found. Please make sure "
        "(1) %s is the fully qualified function "
        "name, and (2) it uses the correct package "
        "name as provided by the package_uris: %s" %
        (class_name, _PREDICTION_CLASS_KEY, package_uris))
  return custom_class
def _validate_prediction_class(user_class):
  """Validates a user provided implementation of Model class.

  Args:
    user_class: The user provided custom Model class.

  Raises:
    PredictionError: for any of the following:
      (1) the user model class is missing the from_path or predict methods
      (2) the predict method does not have the same number of arguments as
          Model.predict
  """
  user_class_name = user_class.__name__
  # Since the user doesn't have access to our Model class. We can only inspect
  # the user_class to check if it conforms to the Model interface.
  if not hasattr(user_class, "from_path"):
    raise PredictionError(
        PredictionError.INVALID_USER_CODE,
        "User provided model class %s must implement the from_path method." %
        user_class_name)
  if not hasattr(user_class, "predict"):
    raise PredictionError(PredictionError.INVALID_USER_CODE,
                          "The provided model class, %s, is missing the "
                          "required predict method." % user_class_name)
  # Check the predict method has the correct number of arguments.
  if sys.version_info.major == 2:
    user_signature = inspect.getargspec(user_class.predict).args  # pylint: disable=deprecated-method
    model_signature = inspect.getargspec(Model.predict).args  # pylint: disable=deprecated-method
  else:
    user_signature = inspect.getfullargspec(user_class.predict).args
    model_signature = inspect.getfullargspec(Model.predict).args
  user_predict_num_args = len(user_signature)
  predict_num_args = len(model_signature)
  # Bug fix: the original compared the two ints with "is not", which tests
  # object identity and is only reliable for CPython's cached small ints;
  # value inequality (!=) is the correct comparison.
  if predict_num_args != user_predict_num_args:
    raise PredictionError(PredictionError.INVALID_USER_CODE,
                          "The provided model class, %s, has a predict method "
                          "with an invalid signature. Expected signature: %s "
                          "User signature: %s" %
                          (user_class_name, model_signature, user_signature))

View File

@@ -0,0 +1,16 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=g-import-not-at-top
"""Framework specific classes and methods for predictions.
"""

View File

@@ -0,0 +1,88 @@
# Copyright 2022 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for running predictions for BQML models trained with TRANSFORM clause."""
import logging
from bigquery_ml_utils import transform_predictor
from google.cloud.ml.prediction import copy_model_to_local
from google.cloud.ml.prediction import ENGINE
from google.cloud.ml.prediction import ENGINE_RUN_TIME
from google.cloud.ml.prediction import FRAMEWORK
from google.cloud.ml.prediction import LOCAL_MODEL_PATH
from google.cloud.ml.prediction import PredictionClient
from google.cloud.ml.prediction import PredictionError
from google.cloud.ml.prediction import SESSION_RUN_TIME
from google.cloud.ml.prediction import Stats
from google.cloud.ml.prediction.frameworks.sk_xg_prediction_lib import SklearnModel
# Framework/engine name recorded in prediction Stats for BQML models trained
# with a TRANSFORM clause.
BQML_TRANSFORM_FRAMEWORK_NAME = "bqml-transform"
class BqmlTransformModel(SklearnModel):
  """The implementation of BQML's Model with TRANSFORM clause."""

  def predict(self, instances, stats=None, **kwargs):
    # Keep the original default: a missing (or falsy) stats gets a fresh one.
    stats = stats or Stats()
    with stats.time(ENGINE_RUN_TIME):
      client = self._client
      return client.predict(instances, stats=stats, **kwargs)
class BqmlTransformClient(PredictionClient):
  """The implementation of BQML's TRANSFORM Client."""

  def __init__(self, predictor):
    self._predictor = predictor

  def predict(self, inputs, stats=None, **kwargs):
    # Keep the original default: a missing (or falsy) stats gets a fresh one.
    stats = stats or Stats()
    stats[FRAMEWORK] = BQML_TRANSFORM_FRAMEWORK_NAME
    stats[ENGINE] = BQML_TRANSFORM_FRAMEWORK_NAME
    with stats.time(SESSION_RUN_TIME):
      try:
        return self._predictor.predict(inputs, **kwargs)
      except Exception as e:  # pylint: disable=broad-except
        logging.exception(
            "Exception during predicting with bqml model with transform clause."
        )
        raise PredictionError(
            PredictionError.FAILED_TO_RUN_MODEL,
            "Exception during predicting with bqml model with transform"
            " clause: " + str(e),
        ) from e
def create_transform_predictor(model_path, **unused_kwargs):
  """Returns a prediction client for the corresponding transform model."""
  logging.info(
      "Downloading the transform model from %s to %s",
      model_path,
      LOCAL_MODEL_PATH,
  )
  # Stage the model locally before loading it.
  copy_model_to_local(model_path, LOCAL_MODEL_PATH)
  try:
    return transform_predictor.Predictor.from_path(LOCAL_MODEL_PATH)
  except Exception as e:  # pylint: disable=broad-except
    logging.exception("Exception during loading bqml transform model.")
    raise PredictionError(
        PredictionError.FAILED_TO_LOAD_MODEL,
        "Exception during loading bqml model with transform clause: " + str(e),
    ) from e
def create_bqml_transform_model(model_path, unused_flags):
  """Returns a transform model from the given model_path."""
  predictor = create_transform_predictor(model_path)
  return BqmlTransformModel(BqmlTransformClient(predictor))

View File

@@ -0,0 +1,86 @@
# Copyright 2022 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for running predictions for BQML xgboost models."""
import logging
from bigquery_ml_utils.inference.xgboost_predictor import Predictor
from google.cloud.ml.prediction import copy_model_to_local
from google.cloud.ml.prediction import ENGINE
from google.cloud.ml.prediction import ENGINE_RUN_TIME
from google.cloud.ml.prediction import FRAMEWORK
from google.cloud.ml.prediction import LOCAL_MODEL_PATH
from google.cloud.ml.prediction import PredictionClient
from google.cloud.ml.prediction import PredictionError
from google.cloud.ml.prediction import SESSION_RUN_TIME
from google.cloud.ml.prediction import Stats
from google.cloud.ml.prediction.frameworks.sk_xg_prediction_lib import SklearnModel
# Framework/engine name recorded in prediction Stats for BQML xgboost models.
BQML_XGBOOST_FRAMEWORK_NAME = "bqml-xgboost"
class BqmlXGBoostModel(SklearnModel):
  """The implementation of BQML's XGboost Model."""

  def predict(self, instances, stats=None, **kwargs):
    # Keep the original default: a missing (or falsy) stats gets a fresh one.
    stats = stats or Stats()
    with stats.time(ENGINE_RUN_TIME):
      client = self._client
      return client.predict(instances, stats=stats, **kwargs)
class BqmlXGBoostClient(PredictionClient):
  """Prediction client that forwards requests to a BQML xgboost predictor."""

  def __init__(self, predictor):
    # The underlying bigquery_ml_utils xgboost predictor instance.
    self._predictor = predictor

  def predict(self, inputs, stats=None, **kwargs):
    stats = stats or Stats()
    # Tag the stats with the framework/engine before timing the run.
    for tag in (FRAMEWORK, ENGINE):
      stats[tag] = BQML_XGBOOST_FRAMEWORK_NAME
    with stats.time(SESSION_RUN_TIME):
      try:
        return self._predictor.predict(inputs, **kwargs)
      except Exception as exc:  # pylint: disable=broad-except
        logging.exception(
            "Exception during predicting with bqml xgboost model."
        )
        raise PredictionError(
            PredictionError.FAILED_TO_RUN_MODEL,
            "Exception during predicting with bqml xgboost model: " + str(exc),
        ) from exc
def create_xgboost_predictor(model_path, **unused_kwargs):
  """Copies the model locally and loads a BQML xgboost Predictor from it."""
  logging.info(
      "Downloading the xgboost model from %s to %s", model_path,
      LOCAL_MODEL_PATH)
  copy_model_to_local(model_path, LOCAL_MODEL_PATH)
  try:
    predictor = Predictor.from_path(LOCAL_MODEL_PATH)
  except Exception as err:
    logging.exception("Exception during loading bqml xgboost model.")
    raise PredictionError(
        PredictionError.FAILED_TO_LOAD_MODEL,
        "Exception during loading bqml xgboost model: " + str(err),
    ) from err
  return predictor
def create_bqml_xgboost_model(model_path, unused_flags):
  """Builds a BQML xgboost model wrapping the predictor at model_path."""
  predictor = create_xgboost_predictor(model_path)
  client = BqmlXGBoostClient(predictor)
  return BqmlXGBoostModel(client)

View File

@@ -0,0 +1,264 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for running predictions for sklearn and xgboost frameworks.
"""
import logging
import os
from .. import prediction_utils
from .._interfaces import PredictionClient
import numpy as np
from sklearn import linear_model
from ..prediction_utils import DEFAULT_MODEL_FILE_NAME_JOBLIB
from ..prediction_utils import DEFAULT_MODEL_FILE_NAME_PICKLE
from ..prediction_utils import load_joblib_or_pickle_model
from ..prediction_utils import PredictionError
# --------------------------
# prediction.frameworks.sk_xg_prediction_lib
# --------------------------
# Scikit-learn and XGBoost related constants
# File name of a natively-saved xgboost booster inside the model directory.
MODEL_FILE_NAME_BST = "model.bst"
# This class is specific to Scikit-learn, and should be moved to a separate
# module. However due to gcloud's complicated copying mechanism we need to keep
# things in one file for now.
class SklearnClient(PredictionClient):
  """A loaded scikit-learn model to be used for prediction.

  Wraps a deserialized sklearn estimator/pipeline and adapts it to the
  PredictionClient interface, recording framework/engine stats.
  """

  def __init__(self, predictor):
    # The deserialized sklearn estimator or pipeline.
    self._predictor = predictor

  def predict(self, inputs, stats=None, **kwargs):
    """Runs the sklearn predictor on inputs, timing the session run.

    Args:
      inputs: model inputs, forwarded verbatim to the predictor's predict().
      stats: optional prediction_utils.Stats for timing/metadata.
      **kwargs: forwarded to the predictor's predict().

    Returns:
      The sklearn predictor's output.

    Raises:
      PredictionError: FAILED_TO_RUN_MODEL if the predictor raises.
    """
    stats = stats or prediction_utils.Stats()
    stats[prediction_utils.
          FRAMEWORK] = prediction_utils.SCIKIT_LEARN_FRAMEWORK_NAME
    stats[
        prediction_utils.ENGINE] = prediction_utils.SCIKIT_LEARN_FRAMEWORK_NAME
    with stats.time(prediction_utils.SESSION_RUN_TIME):
      try:
        return self._predictor.predict(inputs, **kwargs)
      except Exception as e:  # pylint: disable=broad-except
        logging.exception("Exception while predicting with sklearn model.")
        # Chain the original exception so the root cause is preserved
        # (matches the bqml prediction libs in this package).
        raise PredictionError(
            PredictionError.FAILED_TO_RUN_MODEL,
            "Exception during sklearn prediction: " + str(e)) from e
# (TODO:b/68775232) This class is specific to Xgboost, and should be moved to a
# separate module. However due to gcloud's complicated copying mechanism we need
# to keep things in one file for now.
class XgboostClient(PredictionClient):
  """A loaded xgboost model to be used for prediction."""

  def __init__(self, booster):
    # The deserialized xgboost Booster.
    self._booster = booster

  def predict(self, inputs, stats=None, **kwargs):
    """Converts inputs to a DMatrix and runs booster.predict on it.

    Args:
      inputs: model inputs convertible to an xgboost DMatrix.
      stats: optional prediction_utils.Stats for timing/metadata.
      **kwargs: forwarded to Booster.predict().

    Returns:
      The booster's prediction output.

    Raises:
      PredictionError: FAILED_TO_RUN_MODEL if DMatrix construction or the
        booster prediction fails.
    """
    stats = stats or prediction_utils.Stats()
    stats[prediction_utils.FRAMEWORK] = prediction_utils.XGBOOST_FRAMEWORK_NAME
    stats[prediction_utils.ENGINE] = prediction_utils.XGBOOST_FRAMEWORK_NAME
    # TODO(user): Move this to the top once b/64574886 is resolved.
    # Before then, it would work in production since we install xgboost in
    # the Dockerfile, but the problem is the unit test that will fail to build
    # and run since xgboost can not be added as a dependency to this target.
    import xgboost as xgb  # pylint: disable=g-import-not-at-top
    try:
      inputs_dmatrix = xgb.DMatrix(inputs)
    except Exception as e:  # pylint: disable=broad-except
      logging.exception("Could not initialize DMatrix from inputs.")
      # Chain the cause for easier debugging of malformed inputs.
      raise PredictionError(
          PredictionError.FAILED_TO_RUN_MODEL,
          "Could not initialize DMatrix from inputs: " + str(e)) from e
    with stats.time(prediction_utils.SESSION_RUN_TIME):
      try:
        return self._booster.predict(inputs_dmatrix, **kwargs)
      except Exception as e:  # pylint: disable=broad-except
        logging.exception("Exception during predicting with xgboost model: ")
        raise PredictionError(
            PredictionError.FAILED_TO_RUN_MODEL,
            "Exception during xgboost prediction: " + str(e)) from e
class SklearnModel(prediction_utils.BaseModel):
  """Model implementation for scikit-learn estimators."""

  def predict(self, instances, stats=None, **kwargs):
    """Strips the TF-only signature kwarg before delegating to the base."""
    kwargs.pop(prediction_utils.SIGNATURE_KEY, None)
    return super(SklearnModel, self).predict(instances, stats, **kwargs)

  def preprocess(self, instances, stats=None, **kwargs):
    # sklearn predictors consume the raw instances unchanged.
    return instances

  def postprocess(self, predicted_outputs, original_input=None, stats=None,
                  **kwargs):
    """Normalizes predictor output to a plain Python list."""
    if isinstance(predicted_outputs, np.ndarray):
      return predicted_outputs.tolist()
    if not isinstance(predicted_outputs, list):
      raise PredictionError(
          PredictionError.INVALID_OUTPUTS,
          "Bad output type returned."
          "The predict function should return either "
          "a numpy ndarray or a list.")
    return predicted_outputs
class XGBoostModel(SklearnModel):
  """Model implementation for xgboost; reuses sklearn postprocessing."""

  def preprocess(self, instances, stats=None, **kwargs):
    # The xgboost client expects an ndarray rather than the raw list.
    return np.array(instances)
def create_sklearn_client(model_path, **unused_kwargs):
  """Returns a prediction client for the corresponding sklearn model."""
  logging.info("Loading the scikit-learn model file from %s", model_path)
  predictor = load_joblib_or_pickle_model(model_path)
  # Serialized LinearRegression models from sklearn v1.0.2 do not have the
  # positive attribute. Patch this here to prevent errors downstream.
  needs_patch = (
      isinstance(predictor, linear_model.LinearRegression)
      and not hasattr(predictor, "positive"))
  if needs_patch:
    predictor.positive = False
  if not predictor:
    error_msg = "Could not find either {} or {} in {}".format(
        DEFAULT_MODEL_FILE_NAME_JOBLIB, DEFAULT_MODEL_FILE_NAME_PICKLE,
        model_path)
    logging.critical(error_msg)
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL, error_msg)
  # A genuine sklearn estimator/pipeline lives in an "sklearn.*" module,
  # e.g. type(predictor).__module__ -> 'sklearn.svm.classes' or
  # 'sklearn.pipeline'.
  predictor_type = type(predictor)
  if "sklearn" not in predictor_type.__module__:
    error_msg = ("Invalid model type detected: {}.{}. Please make sure the "
                 "model file is an exported sklearn model or pipeline.").format(
                     predictor_type.__module__, predictor_type.__name__)
    logging.critical(error_msg)
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL, error_msg)
  return SklearnClient(predictor)
def create_sklearn_model(model_path, unused_flags):
  """Builds a SklearnModel around a client loaded from model_path."""
  client = create_sklearn_client(model_path)
  return SklearnModel(client)
def create_xgboost_client(model_path, **unused_kwargs):
  """Returns a prediction client for the corresponding xgboost model."""
  logging.info("Loading the xgboost model from %s", model_path)
  # TODO(user): Copy model file to local to reduce copying operation.
  # Prefer a joblib/pickle payload; fall back to a raw model.bst booster.
  booster = (load_joblib_or_pickle_model(model_path) or
             _load_xgboost_model(model_path))
  if not booster:
    error_msg = "Could not find {}, {}, or {} in {}".format(
        DEFAULT_MODEL_FILE_NAME_JOBLIB, DEFAULT_MODEL_FILE_NAME_PICKLE,
        MODEL_FILE_NAME_BST, model_path)
    logging.critical(error_msg)
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL, error_msg)
  # A real booster's type lives in an "xgboost.*" module (e.g. xgboost.core).
  booster_type = type(booster)
  if "xgboost" not in booster_type.__module__:
    error_msg = ("Invalid model type detected: {}.{}. Please make sure the "
                 "model file is an exported xgboost model.").format(
                     booster_type.__module__, booster_type.__name__)
    logging.critical(error_msg)
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL, error_msg)
  return XgboostClient(booster)
def _load_xgboost_model(model_path):
  """Loads an xgboost model from GCS or local.

  Args:
    model_path: path to the directory containing the xgboost model.bst file.
      This path can be either a local path or a GCS path.

  Returns:
    A xgboost.Booster with the model at model_path loaded, or None when no
    model.bst file exists there.

  Raises:
    PredictionError: If there is a problem while loading the file.
  """
  # TODO(user): Move this to the top once b/64574886 is resolved. Before
  # then, it would work in production since we install xgboost in the
  # Dockerfile, but the problem is the unit test that will fail to build and run
  # since xgboost can not be added as a dependency to this target.
  import xgboost as xgb  # pylint: disable=g-import-not-at-top
  if model_path.startswith("gs://"):
    # GCS paths must be staged locally before xgboost can read them.
    prediction_utils.copy_model_to_local(model_path,
                                         prediction_utils.LOCAL_MODEL_PATH)
    model_path = prediction_utils.LOCAL_MODEL_PATH
  model_file = os.path.join(model_path, MODEL_FILE_NAME_BST)
  if not os.path.exists(model_file):
    return None
  try:
    return xgb.Booster(model_file=model_file)
  except xgb.core.XGBoostError as e:
    error_msg = "Could not load the model: {}.".format(
        os.path.join(model_path, MODEL_FILE_NAME_BST))
    logging.exception(error_msg)
    # Chain the underlying XGBoostError so the root cause is preserved.
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "{}. {}.".format(error_msg, str(e))) from e
def create_xgboost_model(model_path, unused_flags):
  """Builds an XGBoostModel around a client loaded from model_path."""
  client = create_xgboost_client(model_path)
  return XGBoostModel(client)
def create_sk_xg_model(model_path, unused_flags):
  """Create xgboost model or sklearn model from the given model_path.

  Args:
    model_path: path to the directory containing only one of model.joblib or
      model.pkl file. This path can be either a local path or a GCS path.
    unused_flags: Required since model creation for other frameworks needs the
      additional flags params. And model creation is called in a framework
      agnostic manner.

  Returns:
    A xgboost model or sklearn model
  """
  # detect framework in ambiguous situations.
  model_obj = load_joblib_or_pickle_model(model_path)
  # Serialized LinearRegression models from sklearn v1.0.2 do not have the
  # positive attribute. Patch this here to prevent errors downstream.
  if (isinstance(model_obj, linear_model.LinearRegression)
      and not hasattr(model_obj, "positive")):
    model_obj.positive = False
  framework = prediction_utils.detect_sk_xgb_framework_from_obj(model_obj)
  # Guard-style dispatch on the detected framework.
  if framework == prediction_utils.SCIKIT_LEARN_FRAMEWORK_NAME:
    return SklearnModel(SklearnClient(model_obj))
  if framework == prediction_utils.XGBOOST_FRAMEWORK_NAME:
    return XGBoostModel(XgboostClient(model_obj))
  error_msg = (
      "Invalid framework detected: {}. Please make sure the model file is "
      "supported by either scikit-learn or xgboost."
  ).format(framework)
  logging.critical(error_msg)
  raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL, error_msg)

View File

@@ -0,0 +1,632 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for running predictions for TF framework.
Note that we avoid importing tensorflow and tensorflow.contrib at the top.
This is because this module gets loaded for other frameworks as well,
and loading xgboost after tensorflow.contrib causes an error.
More context: b/71906188#comment20.
"""
import base64
import collections
import logging
import os
from .. import prediction_utils
from .._interfaces import PredictionClient
import numpy as np
from ..prediction_utils import PredictionError
import six
import tensorflow as tf
# pylint: disable=g-import-not-at-top
# Conditionally import files based on whether this is TF 2.x or TF 1.x.
# A direct check for tf.__version__ fails in some cases, so using the
# hammer of `try`/`catch` blocks instead.
try:
# tf.dtypes and tf.compat weren't added until later versions of TF.
# These imports and constants work for all TF 1.X.
from tensorflow.python.util import compat # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import dtypes # pylint: disable=g-direct-tensorflow-import
SERVING = tf.saved_model.tag_constants.SERVING
DEFAULT_SERVING_SIGNATURE_DEF_KEY = (
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
# Force Tensorflow contrib to load in order to provide access to all the
# libraries in contrib to batch prediction (also, when using SESSION_RUN
# instead of MODEL_SERVER for online prediction, which we no longer do).
# However, contrib is no longer a part of TensorFlow 2.0, so check for its
# existence first.
try:
import tensorflow.contrib # pylint: disable=unused-import
# TF 1.15 introduced lazy loading for tensorflow.contrib, but doing
# a dir forces it to load.
dir(tensorflow.contrib)
except: # pylint: disable=bare-except
pass
except: # pylint: disable=bare-except
import tensorflow.compat.v1 as tf
from tensorflow import dtypes
from tensorflow import compat
SERVING = tf.saved_model.SERVING
DEFAULT_SERVING_SIGNATURE_DEF_KEY = (
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
tf.disable_v2_behavior()
# pylint: enable=g-import-not-at-top
# --------------------------
# prediction.frameworks.tf_prediction_lib
# --------------------------
# Custom TF ops are shipped under <model_dir>/assets.extra as *.so files.
_CUSTOM_OP_DIRECTORY_NAME = "assets.extra"
_CUSTOM_OP_SUFFIX = "*.so"
# GCS-hosted op libraries are staged here before tf.load_op_library.
_CUSTOM_OP_LOCAL_DIR = "/tmp/custom_ops/"
def columnarize(instances):
  """Columnarize inputs.

  Each line in the input is a dictionary of input names to the value
  for that input (a single instance). For each input "column", this method
  appends each of the input values to a list. The result is a dict mapping
  input names to a batch of input data. This can be directly used as the
  feed dict during prediction.

  For example,

    instances = [{"a": [1.0, 2.0], "b": "a"},
                 {"a": [3.0, 4.0], "b": "c"},
                 {"a": [5.0, 6.0], "b": "e"},]
    batch = prediction_server_lib.columnarize(instances)
    assert batch == {"a": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                     "b": ["a", "c", "e"]}

  Arguments:
    instances: (list of dict) where the dictionaries map input names
      to the values for those inputs.

  Returns:
    A dictionary mapping input names to values, as described above.
  """
  columns = collections.defaultdict(list)
  for instance in instances:
    # dict.items() iterates identically under py2 and py3; the six
    # indirection was unnecessary here.
    for k, v in instance.items():
      columns[k].append(v)
  return columns
def rowify(columns):
  """Converts columnar input to row data.

  Consider the following code:

    columns = {"prediction": np.array([1,             # 1st instance
                                       0,             # 2nd
                                       1]),           # 3rd
               "scores": np.array([[0.1, 0.9],        # 1st instance
                                   [0.7, 0.3],        # 2nd
                                   [0.4, 0.6]])}      # 3rd

  Then rowify will return the equivalent of:

    [{"prediction": 1, "scores": [0.1, 0.9]},
     {"prediction": 0, "scores": [0.7, 0.3]},
     {"prediction": 1, "scores": [0.4, 0.6]}]

  (each row is yielded; no list is actually created).

  Arguments:
    columns: (dict) mapping names to numpy arrays, where the arrays
      contain a batch of data.

  Raises:
    PredictionError: if the outer dimension of each input isn't identical
      for each of element.

  Yields:
    A map with a single instance, as described above. Note: instances
    is not a numpy array.
  """
  # dict view iteration and range() behave the same on py2/py3 for this
  # usage; the six indirection was unnecessary.
  sizes_set = {e.shape[0] for e in columns.values()}

  # All the elements in the length array should be identical. Otherwise,
  # raise an exception.
  if len(sizes_set) != 1:
    sizes_dict = {name: e.shape[0] for name, e in columns.items()}
    raise PredictionError(
        PredictionError.INVALID_OUTPUTS,
        "Bad output from running tensorflow session: outputs had differing "
        "sizes in the batch (outer) dimension. See the outputs and their "
        "size: %s. Check your model for bugs that effect the size of the "
        "outputs." % sizes_dict)
  # Pick an arbitrary value in the map to get its size.
  num_instances = len(next(iter(columns.values())))
  for row in range(num_instances):
    yield {
        name: output[row, ...].tolist()
        for name, output in columns.items()
    }
def canonicalize_single_tensor_input(instances, tensor_name):
  """Canonicalize single input tensor instances into list of dicts.

  Instances that are single input tensors may or may not be provided with their
  tensor name. The following are both valid instances:
    1) instances = [{"x": "a"}, {"x": "b"}, {"x": "c"}]
    2) instances = ["a", "b", "c"]
  This function canonicalizes the input instances to be of type 1).

  Arguments:
    instances: single input tensor instances as supplied by the user to the
      predict method.
    tensor_name: the expected name of the single input tensor.

  Raises:
    PredictionError: if the wrong tensor name is supplied to instances.

  Returns:
    A list of dicts. Where each dict is a single instance, mapping the
    tensor_name to the value (as supplied by the original instances).
  """

  # There are 3 cases (assuming the tensor name is "t", tensor = "abc"):
  #   1) {"t": "abc"}       -> already canonical, keep as-is
  #   2) "abc"              -> wrap with the tensor name
  #   3) {"y": ...}         -> wrong tensor name is given, reject
  def _canonicalize_one(x):
    if isinstance(x, dict):
      if len(x) == 1 and tensor_name == list(x.keys())[0]:
        return x
      raise PredictionError(PredictionError.INVALID_INPUTS,
                            "Expected tensor name: %s, got tensor name: %s." %
                            (tensor_name, list(x.keys())))
    return {tensor_name: x}

  if not isinstance(instances, list):
    instances = [instances]
  return [_canonicalize_one(x) for x in instances]
# TODO(user): when we no longer load the model to get the signature
# consider making this a named constructor on SessionClient.
def load_tf_model(model_path,
                  tags=(SERVING,),
                  config=None):
  """Loads the model at the specified path.

  Args:
    model_path: the path to either session_bundle or SavedModel
    tags: the tags that determines the model to load.
    config: tf.ConfigProto containing session configuration options.

  Returns:
    A pair of (Session, map<string, SignatureDef>) objects.

  Raises:
    PredictionError: if the model could not be loaded.
  """
  # Load any custom op libraries shipped alongside the model first, so the
  # graph import below can resolve their ops.
  _load_tf_custom_op(model_path)
  if tf.saved_model.loader.maybe_saved_model_directory(model_path):
    try:
      logging.info("Importing tensorflow.contrib in load_tf_model")
      # NOTE(review): graph=None (default graph) is used only for TF 1.0.x;
      # later versions get a fresh Graph — presumably to isolate the loaded
      # model. Confirm before changing.
      if tf.__version__.startswith("1.0"):
        session = tf.Session(target="", graph=None, config=config)
      else:
        session = tf.Session(target="", graph=tf.Graph(), config=config)
      meta_graph = tf.saved_model.loader.load(
          session, tags=list(tags), export_dir=model_path)
    except Exception as e:  # pylint: disable=broad-except
      msg = ("Failed to load the model due to bad model data. "
             "tags: %s" % (list(tags),))
      logging.exception(msg)
      raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                            "%s\n%s" % (msg, str(e)))
  else:
    # Anything that is not a SavedModel directory (e.g. old session_bundle
    # exports) is rejected outright.
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "Cloud ML only supports TF 1.0 or above and models "
                          "saved in SavedModel format.")
  if session is None:
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "Failed to create session when loading the model")
  if not meta_graph.signature_def:
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "MetaGraph must have at least one signature_def.")
  # Remove invalid signatures from the signature map: a signature whose
  # declared dtypes disagree with the graph is dropped rather than failing
  # the whole load.
  invalid_signatures = []
  for signature_name in meta_graph.signature_def:
    try:
      signature = meta_graph.signature_def[signature_name]
      _update_dtypes(session.graph, signature.inputs)
      _update_dtypes(session.graph, signature.outputs)
    except ValueError as e:
      logging.warn("Error updating signature %s: %s", signature_name, str(e))
      invalid_signatures.append(signature_name)
  # Deletion happens after iteration to avoid mutating the map mid-loop.
  for signature_name in invalid_signatures:
    del meta_graph.signature_def[signature_name]
  return session, meta_graph.signature_def
def _update_dtypes(graph, interface):
  """Adds dtype to TensorInfos in interface if necessary.

  If already present, validates TensorInfo matches values in the graph.
  TensorInfo is updated in place.

  Args:
    graph: the TensorFlow graph; used to lookup datatypes of tensors.
    interface: map from alias to TensorInfo object.

  Raises:
    ValueError: if the data type in the TensorInfo does not match the type
      found in graph.
  """
  for alias, info in six.iteritems(interface):
    # Postpone conversion to enum for better error messages.
    dtype = graph.get_tensor_by_name(info.name).dtype
    if not info.dtype:
      info.dtype = dtype.as_datatype_enum
    elif info.dtype != dtype.as_datatype_enum:
      # Format the enum value: `dtype` itself is a tf.DType object, which
      # the previous "%d" conversion could not format and would raise a
      # TypeError instead of the intended ValueError.
      raise ValueError("Specified data types do not match for alias %s. "
                       "Graph has %d while TensorInfo reports %d." %
                       (alias, dtype.as_datatype_enum, info.dtype))
# (TODO:b/68775232): Move this to a Tensorflow specific library.
class TensorFlowClient(PredictionClient):
  """A client for Prediction that uses Session.run."""

  def __init__(self, signature_map, *args, **kwargs):
    # Map of signature name -> SignatureDef for the loaded model.
    self._signature_map = signature_map
    super(TensorFlowClient, self).__init__(*args, **kwargs)

  @property
  def signature_map(self):
    return self._signature_map

  def get_signature(self, signature_name=None):
    """Gets tensorflow signature for the given signature_name.

    Lookup order:
      1) A single-entry map is unambiguous: return its only signature when no
         name is given.
      2) Otherwise use signature_name (or the serving default key) and fail
         if it is absent.

    Args:
      signature_name: string The signature name to use to choose the signature
        from the signature map.

    Returns:
      a pair of signature_name and signature. The first element is the
      signature name in string that is actually used. The second one is the
      signature.

    Raises:
      PredictionError: when the signature is not found with the given signature
        name or when there are more than one signatures in the signature map.
    """
    signatures = self.signature_map
    if not signature_name and len(signatures) == 1:
      only_key = next(iter(signatures))
      return only_key, signatures[only_key]
    key = signature_name or DEFAULT_SERVING_SIGNATURE_DEF_KEY
    if key in signatures:
      return key, signatures[key]
    raise PredictionError(
        PredictionError.INVALID_INPUTS,
        "No signature found for signature key %s." % signature_name)
class SessionClient(TensorFlowClient):
  """A client for Prediction that uses Session.run."""

  def __init__(self, session, signature_map):
    # The live tf.Session holding the loaded graph.
    self._session = session
    super(SessionClient, self).__init__(signature_map)

  def predict(self, inputs, stats=None,
              signature_name=None, **unused_kwargs):
    """Produces predictions for the given inputs.

    Args:
      inputs: a dict mapping input names to values
      stats: Stats object for recording timing information.
      signature_name: name of SignatureDef to use in this prediction

    Returns:
      A dict mapping output names to output values, similar to the input
      dict.

    Raises:
      PredictionError: INVALID_INPUTS if the inputs don't match the
        signature, FAILED_TO_RUN_MODEL if session.run raises.
    """
    stats = stats or prediction_utils.Stats()
    stats[prediction_utils.ENGINE] = "SessionRun"
    stats[
        prediction_utils.FRAMEWORK] = prediction_utils.TENSORFLOW_FRAMEWORK_NAME

    with stats.time(prediction_utils.UNALIAS_TIME):
      _, signature = self.get_signature(signature_name)
      fetches = [output.name for output in signature.outputs.values()]
      # Translate user-facing aliases to the graph's real tensor names.
      try:
        unaliased = {
            signature.inputs[key].name: val
            for key, val in six.iteritems(inputs)
        }
      except Exception as e:  # pylint: disable=broad-except
        logging.exception("Input mismatch.")
        raise PredictionError(PredictionError.INVALID_INPUTS,
                              "Input mismatch: " + str(e))

    with stats.time(prediction_utils.SESSION_RUN_TIME):
      try:
        # TODO(user): measure the actual session.run() time, even in the
        # case of ModelServer.
        outputs = self._session.run(fetches=fetches, feed_dict=unaliased)
      except Exception as e:  # pylint: disable=broad-except
        logging.exception("Exception running the graph.")
        raise PredictionError(PredictionError.FAILED_TO_RUN_MODEL,
                              "Exception during running the graph: " + str(e))

    with stats.time(prediction_utils.ALIAS_TIME):
      # Re-attach the output aliases; fetches were built from the same
      # ordered view of signature.outputs, so zip pairs them correctly.
      return dict(zip(six.iterkeys(signature.outputs), outputs))
class TensorFlowModel(prediction_utils.BaseModel):
  """The default implementation of the Model interface that uses TensorFlow.

  This implementation optionally performs preprocessing and postprocessing
  using the provided functions. These functions accept a single instance
  as input and produce a corresponding output to send to the prediction
  client.
  """

  def _get_columns(self, instances, stats, signature):
    """Columnarize the instances, appending input_name, if necessary.

    Instances are the same instances passed to the predict() method. Since
    models with a single input can accept the raw input without the name,
    we create a dict here with that name.

    This list of instances is then converted into a column-oriented format:
    The result is a dictionary mapping input name to a list of values for just
    that input (one entry per row in the original instances list).

    Args:
      instances: the list of instances as provided to the predict() method.
      stats: Stats object for recording timing information.
      signature: SignatureDef for the current request.

    Returns:
      A dictionary mapping input names to their values.

    Raises:
      PredictionError: if an error occurs during prediction.
    """
    with stats.time(prediction_utils.COLUMNARIZE_TIME):
      columns = columnarize(instances)
      for k, v in six.iteritems(columns):
        if k not in signature.inputs.keys():
          raise PredictionError(
              PredictionError.INVALID_INPUTS,
              "Unexpected tensor name: %s" % k)
        # Detect whether or not the user omits an input in one or more inputs.
        # TODO(user): perform this check in columnarize?
        if isinstance(v, list) and len(v) != len(instances):
          raise PredictionError(
              PredictionError.INVALID_INPUTS,
              "Input %s was missing in at least one input instance." % k)
    return columns

  # TODO(user): can this be removed?
  def is_single_input(self, signature):
    """Returns True if the graph only has one input tensor."""
    return len(signature.inputs) == 1

  # TODO(user): can this be removed?
  def is_single_string_input(self, signature):
    """Returns True if the graph only has one string input tensor."""
    if self.is_single_input(signature):
      dtype = list(signature.inputs.values())[0].dtype
      return dtype == dtypes.string.as_datatype_enum
    return False

  def get_signature(self, signature_name=None):
    """Delegates signature lookup to the underlying client."""
    return self._client.get_signature(signature_name)

  def preprocess(self, instances, stats=None, signature_name=None, **kwargs):
    """Canonicalizes and columnarizes instances for Session.run."""
    # Guard against a missing Stats object: _get_columns times its work via
    # stats.time(), which would raise AttributeError on None. This matches
    # the "stats = stats or Stats()" pattern used elsewhere in this package.
    stats = stats or prediction_utils.Stats()
    _, signature = self.get_signature(signature_name)
    preprocessed = self._canonicalize_input(instances, signature)
    return self._get_columns(preprocessed, stats, signature)

  def _canonicalize_input(self, instances, signature):
    """Preprocess single-input instances to be dicts if they aren't already."""
    # The instances should be already (b64-) decoded here.
    if not self.is_single_input(signature):
      return instances
    tensor_name = list(signature.inputs.keys())[0]
    return canonicalize_single_tensor_input(instances, tensor_name)

  def postprocess(self, predicted_output, original_input=None, stats=None,
                  signature_name=None, **kwargs):
    """Performs the necessary transformations on the prediction results.

    The transformations include rowifying the predicted results, and also
    making sure that each input/output is a dict mapping input/output alias to
    the value for that input/output.

    Args:
      predicted_output: list of instances returned by the predict() method on
        preprocessed instances.
      original_input: List of instances, before any pre-processing was applied.
      stats: Stats object for recording timing information.
      signature_name: the signature name to find out the signature.
      **kwargs: Additional keyword arguments for postprocessing

    Returns:
      A list which is a dict mapping output alias to the output.
    """
    # Guard against a missing Stats object (stats defaults to None but is
    # dereferenced below via stats.time()).
    stats = stats or prediction_utils.Stats()
    _, signature = self.get_signature(signature_name)
    with stats.time(prediction_utils.ROWIFY_TIME):
      # When returned element only contains one result (batch size == 1),
      # tensorflow's session.run() will return a scalar directly instead of a
      # a list. So we need to listify that scalar.
      # TODO(user): verify this behavior is correct.
      def listify(value):
        if not hasattr(value, "shape"):
          return np.asarray([value], dtype=object)
        elif not value.shape:
          # TODO(user): pretty sure this is a bug that only exists because
          # samples like iris have a bug where they use tf.squeeze which removes
          # the batch dimension. The samples should be fixed.
          return np.expand_dims(value, axis=0)
        else:
          return value

      postprocessed_outputs = {
          alias: listify(val)
          for alias, val in six.iteritems(predicted_output)
      }
      postprocessed_outputs = rowify(postprocessed_outputs)
    postprocessed_outputs = list(postprocessed_outputs)
    with stats.time(prediction_utils.ENCODE_TIME):
      try:
        postprocessed_outputs = encode_base64(
            postprocessed_outputs, signature.outputs)
      except PredictionError as e:
        logging.exception("Encode base64 failed.")
        raise PredictionError(PredictionError.INVALID_OUTPUTS,
                              "Prediction failed during encoding instances: {0}"
                              .format(e.error_detail))
      except ValueError as e:
        logging.exception("Encode base64 failed.")
        raise PredictionError(PredictionError.INVALID_OUTPUTS,
                              "Prediction failed during encoding instances: {0}"
                              .format(e))
      except Exception as e:  # pylint: disable=broad-except
        logging.exception("Encode base64 failed.")
        raise PredictionError(PredictionError.INVALID_OUTPUTS,
                              "Prediction failed during encoding instances")
    return postprocessed_outputs

  @classmethod
  def from_client(cls, client, unused_model_path, **unused_kwargs):
    """Creates a TensorFlowModel from a SessionClient and model data files."""
    return cls(client)

  @property
  def signature_map(self):
    return self._client.signature_map
def create_tf_session_client(model_dir,
                             tags=(SERVING,),
                             config=None):
  """Loads the SavedModel at model_dir and wraps it in a SessionClient."""
  session, signature_map = load_tf_model(model_dir, tags, config)
  return SessionClient(session, signature_map)
def encode_base64(instances, outputs_map):
  """Encodes binary data in a JSON-friendly way.

  Args:
    instances: list of prediction outputs; either raw values (for a
      single-output signature) or dicts mapping output alias to value.
    outputs_map: map from output alias to TensorInfo for the signature.

  Returns:
    The instances with any string tensor whose alias ends in "_bytes"
    base64-encoded.

  Raises:
    ValueError: if instances is not a list, or if raw (non-dict) values are
      supplied while the signature has more than one output tensor.
  """
  if not isinstance(instances, list):
    raise ValueError("only lists allowed in output; got %s" %
                     (type(instances),))

  if not instances:
    return instances
  first_value = instances[0]
  if not isinstance(first_value, dict):
    if len(outputs_map) != 1:
      # Bug fix: this previously *returned* the ValueError instead of
      # raising it, silently handing the exception object downstream.
      raise ValueError("The first instance was a string, but there are "
                       "more than one output tensor, so dict expected.")
    # Only string tensors whose name ends in _bytes needs encoding.
    tensor_name, tensor_info = next(iter(outputs_map.items()))
    tensor_type = tensor_info.dtype
    if tensor_type == dtypes.string:
      instances = _encode_str_tensor(instances, tensor_name)
    return instances

  encoded_data = []
  for instance in instances:
    encoded_instance = {}
    for tensor_name, tensor_info in six.iteritems(outputs_map):
      tensor_type = tensor_info.dtype
      tensor_data = instance[tensor_name]
      if tensor_type == dtypes.string:
        tensor_data = _encode_str_tensor(tensor_data, tensor_name)
      encoded_instance[tensor_name] = tensor_data
    encoded_data.append(encoded_instance)
  return encoded_data
def _encode_str_tensor(data, tensor_name):
  """Encodes string-tensor data for JSON output.

  Recurses through nested lists. Leaf values are base64-wrapped as
  {"b64": ...} when the tensor name ends in '_bytes' (binary payload),
  otherwise converted to text.

  Args:
    data: Leaf tensor data (bytes on Python 3, str on Python 2) or a
      possibly-nested list of such values.
    tensor_name: Name of the tensor the data came from.

  Returns:
    A JSON-friendly encoded version of the data.
  """
  if isinstance(data, list):
    return [_encode_str_tensor(item, tensor_name) for item in data]
  is_binary_tensor = tensor_name.endswith("_bytes")
  if is_binary_tensor:
    encoded = base64.b64encode(data)
    return {"b64": compat.as_text(encoded)}
  return compat.as_text(data)
def _load_tf_custom_op(model_path):
  """Loads a custom TF OP (in .so format) from /assets.extra directory.

  Scans the model's custom-op directory for shared-object files, copies any
  GCS-hosted files to a local staging directory (tf.load_op_library needs a
  local path), and loads each library. Load failures are logged but not
  raised, so prediction proceeds and fails later only if the ops are needed.

  Args:
    model_path: Path to the exported model directory.
  """
  assets_dir = os.path.join(model_path, _CUSTOM_OP_DIRECTORY_NAME)
  if tf.gfile.IsDirectory(assets_dir):
    custom_ops_pattern = os.path.join(assets_dir, _CUSTOM_OP_SUFFIX)
    for custom_op_path_original in tf.gfile.Glob(custom_ops_pattern):
      logging.info("Found custom op file: %s", custom_op_path_original)
      if custom_op_path_original.startswith("gs://"):
        if not os.path.isdir(_CUSTOM_OP_LOCAL_DIR):
          os.makedirs(_CUSTOM_OP_LOCAL_DIR)
        custom_op_path_local = os.path.join(
            _CUSTOM_OP_LOCAL_DIR, os.path.basename(custom_op_path_original))
        # BUG FIX: log message typo "custop op" -> "custom op".
        logging.info("Copying custom op from: %s to: %s",
                     custom_op_path_original, custom_op_path_local)
        tf.gfile.Copy(custom_op_path_original, custom_op_path_local, True)
      else:
        custom_op_path_local = custom_op_path_original
      try:
        logging.info("Loading custom op: %s", custom_op_path_local)
        logging.info("TF Version: %s", tf.__version__)
        tf.load_op_library(custom_op_path_local)
      except RuntimeError as e:
        logging.exception(
            "Failed to load custom op: %s with error: %s. Prediction "
            "will likely fail due to missing operations.", custom_op_path_local,
            e)

View File

@@ -0,0 +1,103 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for running predictions.
Includes (from the Cloud ML SDK):
- _predict_lib
Important changes:
- Remove interfaces for TensorFlowModel (they don't change behavior).
- Set from_client(skip_preprocessing=True) and remove the pre-processing code.
"""
from . import custom_code_utils
from . import prediction_utils
# --------------------------
# prediction.prediction_lib
# --------------------------
def create_model(client, model_path, framework=None, **unused_kwargs):
  """Creates and returns the appropriate model.

  Creates and returns a Model if no user specified model is
  provided. Otherwise, the user specified model is imported, created, and
  returned.

  Args:
    client: An instance of PredictionClient for performing prediction.
    model_path: The path to the exported model (e.g. session_bundle or
      SavedModel)
    framework: The framework used to train the model.
    **unused_kwargs: Unused extra construction arguments.

  Returns:
    An instance of the appropriate model class.

  Raises:
    ValueError: If framework is not one of the supported framework names.
  """
  custom_model = custom_code_utils.create_user_model(model_path, None)
  if custom_model:
    return custom_model
  framework = framework or prediction_utils.TENSORFLOW_FRAMEWORK_NAME
  if framework == prediction_utils.TENSORFLOW_FRAMEWORK_NAME:
    from .frameworks import tf_prediction_lib  # pylint: disable=g-import-not-at-top
    model_cls = tf_prediction_lib.TensorFlowModel
  elif framework == prediction_utils.SCIKIT_LEARN_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    model_cls = sk_xg_prediction_lib.SklearnModel
  elif framework == prediction_utils.XGBOOST_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    model_cls = sk_xg_prediction_lib.XGBoostModel
  else:
    # BUG FIX: an unrecognized framework previously fell through and raised
    # an opaque UnboundLocalError on the return below.
    raise ValueError("Invalid framework: {}".format(framework))
  return model_cls(client)
def create_client(framework, model_path, **kwargs):
  """Creates and returns the appropriate prediction client.

  Creates and returns a PredictionClient based on the provided framework.

  Args:
    framework: The framework used to train the model.
    model_path: The path to the exported model (e.g. session_bundle or
      SavedModel)
    **kwargs: Optional additional params to pass to the client constructor
      (such as TF tags).

  Returns:
    An instance of the appropriate PredictionClient.

  Raises:
    ValueError: If framework is not one of the supported framework names.
  """
  framework = framework or prediction_utils.TENSORFLOW_FRAMEWORK_NAME
  if framework == prediction_utils.TENSORFLOW_FRAMEWORK_NAME:
    from .frameworks import tf_prediction_lib  # pylint: disable=g-import-not-at-top
    create_client_fn = tf_prediction_lib.create_tf_session_client
  elif framework == prediction_utils.SCIKIT_LEARN_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    create_client_fn = sk_xg_prediction_lib.create_sklearn_client
  elif framework == prediction_utils.XGBOOST_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    create_client_fn = sk_xg_prediction_lib.create_xgboost_client
  else:
    # BUG FIX: an unrecognized framework previously fell through and raised
    # an opaque UnboundLocalError on the call below.
    raise ValueError("Invalid framework: {}".format(framework))
  return create_client_fn(model_path, **kwargs)
def local_predict(model_dir=None, signature_name=None, instances=None,
                  framework=None, **kwargs):
  """Run a prediction locally.

  Builds a client and model for the given framework, base64-decodes the
  instances when the model requires it, and runs prediction.

  Args:
    model_dir: Path to the exported model directory.
    signature_name: Optional name of the model signature to use.
    instances: List of input instances to predict on.
    framework: Framework name; defaults to TensorFlow when not given.
    **kwargs: Extra arguments forwarded to the client constructor.

  Returns:
    A dict with a single "predictions" key holding the list of predictions.
  """
  framework = framework or prediction_utils.TENSORFLOW_FRAMEWORK_NAME
  client = create_client(framework, model_dir, **kwargs)
  model = create_model(client, model_dir, framework)
  # TF models with string inputs (or unknown signatures) expect raw bytes,
  # so {"b64": ...} wrappers must be decoded before prediction.
  if prediction_utils.should_base64_decode(framework, model, signature_name):
    instances = prediction_utils.decode_base64(instances)
  predictions = model.predict(instances, signature_name=signature_name)
  return {"predictions": list(predictions)}

View File

@@ -0,0 +1,651 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for running predictions."""
import base64
import collections
import contextlib
import json
import logging
import os
import pickle
import subprocess
import sys
import time
import timeit
from ._interfaces import Model
import six
from tensorflow.python.framework import dtypes # pylint: disable=g-direct-tensorflow-import
# collections.Mapping and friends moved to collections.abc (removed from the
# top-level module in Python 3.10), so alias the right module once here.
collections_lib = collections
if sys.version_info > (3, 8):
  collections_lib = collections.abc
# --------------------------
# prediction.common
# --------------------------
# Stats/header keys identifying the prediction engine and framework.
ENGINE = "Prediction-Engine"
ENGINE_RUN_TIME = "Prediction-Engine-Run-Time"
FRAMEWORK = "Framework"
# Subdirectory names used inside a deployed model directory.
MODEL_SUBDIRECTORY = "model"
PREPARED_MODEL_SUBDIRECTORY = "prepared_model"
# Canonical framework names used throughout this package.
SCIKIT_LEARN_FRAMEWORK_NAME = "scikit_learn"
SK_XGB_FRAMEWORK_NAME = "sk_xgb"
XGBOOST_FRAMEWORK_NAME = "xgboost"
TENSORFLOW_FRAMEWORK_NAME = "tensorflow"
CUSTOM_FRAMEWORK_NAME = "custom"
# Stats keys for the pre/post-processing phases of BaseModel.predict.
PREPROCESS_TIME = "Prediction-Preprocess-Time"
POSTPROCESS_TIME = "Prediction-Postprocess-Time"
# Default model names
DEFAULT_MODEL_FILE_NAME_JOBLIB = "model.joblib"
DEFAULT_MODEL_FILE_NAME_PICKLE = "model.pkl"
TENSORFLOW_SPECIFIC_MODEL_FILE_NAMES = (
    "saved_model.pb",
    "saved_model.pbtxt",
)
SCIKIT_LEARN_MODEL_FILE_NAMES = (
    DEFAULT_MODEL_FILE_NAME_JOBLIB,
    DEFAULT_MODEL_FILE_NAME_PICKLE,
)
XGBOOST_SPECIFIC_MODEL_FILE_NAMES = ("model.bst",)
# Additional TF keyword arguments
INPUTS_KEY = "inputs"
OUTPUTS_KEY = "outputs"
SIGNATURE_KEY = "signature_name"
# Stats
COLUMNARIZE_TIME = "Prediction-Columnarize-Time"
UNALIAS_TIME = "Prediction-Unalias-Time"
ENCODE_TIME = "Prediction-Encode-Time"
SESSION_RUN_TIME = "Prediction-Session-Run-Time"
ALIAS_TIME = "Prediction-Alias-Time"
ROWIFY_TIME = "Prediction-Rowify-Time"
# TODO(user): Consider removing INPUT_PROCESSING_TIME during cleanup.
SESSION_RUN_ENGINE_NAME = "TF_SESSION_RUN"
# Location of where model files are copied locally.
LOCAL_MODEL_PATH = "/tmp/model"
# (message, code) pair used for the PredictionError class constants below.
PredictionErrorType = collections.namedtuple(
    "PredictionErrorType", ("message", "code"))
# Keys related to requests and responses to prediction server.
PREDICTIONS_KEY = "predictions"
# NOTE(review): OUTPUTS_KEY is also defined above in the "Additional TF
# keyword arguments" section with the same value; this re-assignment is
# redundant — consider removing one of the two.
OUTPUTS_KEY = "outputs"
INSTANCES_KEY = "instances"
class PredictionError(Exception):
  """Customer exception for known prediction exception.

  Instances are constructed as PredictionError(<constant>, <detail string>),
  where <constant> is one of the PredictionErrorType class attributes below;
  the properties read those two positional args back out of self.args.
  """

  # The error code for prediction.
  FAILED_TO_LOAD_MODEL = PredictionErrorType(
      message="Failed to load model", code=0)
  INVALID_INPUTS = PredictionErrorType("Invalid inputs", code=1)
  FAILED_TO_RUN_MODEL = PredictionErrorType(
      message="Failed to run the provided model", code=2)
  INVALID_OUTPUTS = PredictionErrorType(
      message="There was a problem processing the outputs", code=3)
  INVALID_USER_CODE = PredictionErrorType(
      message="There was a problem processing the user code", code=4)
  FAILED_TO_ACCESS_METADATA_SERVER = PredictionErrorType(
      message="Could not get an access token from the metadata server",
      code=5)
  # When adding new exception, please update the ERROR_MESSAGE_ list as well as
  # unittest.

  @property
  def error_code(self):
    """Numeric code of the PredictionErrorType passed as the first arg."""
    return self.args[0].code

  @property
  def error_message(self):
    """Human-readable message of the PredictionErrorType first arg."""
    return self.args[0].message

  @property
  def error_detail(self):
    """Free-form detail string passed as the second constructor arg."""
    return self.args[1]

  def __str__(self):
    return ("%s: %s (Error code: %d)" % (self.error_message,
                                         self.error_detail, self.error_code))
MICRO = 1000000
MILLI = 1000


class Timer(object):
  """Context manager that measures the wall time of a `with` block.

  The timer starts on __enter__ and stops on __exit__. While the block is
  still executing, the duration properties report the elapsed time so far;
  after __exit__ they are frozen at the final duration. Reusing one Timer
  object in several (non-nested) `with` statements reports the most recent
  run; do not nest the same Timer inside itself.

  Example usage:
    with Timer() as timer:
      foo()
    print(timer.duration_secs)
  """

  def __init__(self, timer_fn=None):
    # start/end hold raw clock readings; None until the timer is used.
    self.start = None
    self.end = None
    self._get_time = timer_fn or timeit.default_timer

  def __enter__(self):
    # Reset end so a re-used timer reports live values again.
    self.end = None
    self.start = self._get_time()
    return self

  def __exit__(self, exc_type, value, traceback):
    self.end = self._get_time()
    # Never suppress exceptions raised inside the block.
    return False

  @property
  def seconds(self):
    current = self._get_time()
    # Fall back to "now" while the timer is still running (or unused).
    return (self.end or current) - (self.start or current)

  @property
  def microseconds(self):
    return int(MICRO * self.seconds)

  @property
  def milliseconds(self):
    return int(MILLI * self.seconds)


class Stats(dict):
  """Dict-like collector of stats with a helper for timing code blocks.

  Stats are plain dict entries:
    stats = Stats()
    stats["count"] = 1

  Timings recorded through the `time` context manager are stored in
  microseconds:
    with stats.time("foo_time"):
      foo()
    print(stats["foo_time"])
  """

  @contextlib.contextmanager
  def time(self, name, timer_fn=None):
    timer = Timer(timer_fn)
    with timer:
      yield timer
    self[name] = timer.microseconds
class BaseModel(Model):
  """The base definition of an internal Model interface."""

  def __init__(self, client):
    """Constructs a BaseModel.

    Args:
      client: An instance of PredictionClient for performing prediction.
    """
    self._client = client
    # Reserved for a user-supplied processor; not populated in this class.
    self._user_processor = None

  def preprocess(self, instances, stats=None, **kwargs):
    """Runs the preprocessing function on the instances.

    Args:
      instances: list of instances as provided to the predict() method.
      stats: Stats object for recording timing information.
      **kwargs: Additional keyword arguments for preprocessing.

    Returns:
      A new list of preprocessed instances. Each instance is as described
      in the predict() method.
    """
    # Default is a no-op (returns None); subclasses override to transform
    # instances before prediction.
    pass

  def postprocess(self, predicted_output, original_input=None, stats=None,
                  **kwargs):
    """Runs the postprocessing function on the instances.

    Args:
      predicted_output: list of instances returned by the predict() method on
        preprocessed instances.
      original_input: List of instances, before any pre-processing was applied.
      stats: Stats object for recording timing information.
      **kwargs: Additional keyword arguments for postprocessing.

    Returns:
      A new list of postprocessed instances.
    """
    # Default is a no-op (returns None); subclasses override to transform
    # prediction outputs.
    pass

  def predict(self, instances, stats=None, **kwargs):
    """Runs preprocessing, predict, and postprocessing on the input."""
    stats = stats or Stats()
    self._validate_kwargs(kwargs)
    # Each stage is timed separately so per-phase latency is visible in stats.
    with stats.time(PREPROCESS_TIME):
      preprocessed = self.preprocess(instances, stats=stats, **kwargs)
    with stats.time(ENGINE_RUN_TIME):
      predicted_outputs = self._client.predict(
          preprocessed, stats=stats, **kwargs)
    with stats.time(POSTPROCESS_TIME):
      postprocessed = self.postprocess(
          predicted_outputs, original_input=instances, stats=stats, **kwargs)
    return postprocessed

  def _validate_kwargs(self, kwargs):
    """Validates and sets defaults for extra predict keyword arguments.

    Modifies the keyword args dictionary in-place. Keyword args will be included
    into pre/post-processing and the client predict method.
    Can raise Exception to error out of request on bad keyword args.
    If no additional args are required, pass.

    Args:
      kwargs: Dictionary (str->str) of keyword arguments to check.
    """
    pass

  def get_signature(self, signature_name=None):
    """Gets model signature of inputs and outputs.

    Currently only used for Tensorflow model. May be extended for use with
    XGBoost and Sklearn in the future.

    Args:
      signature_name: str of name of signature

    Returns:
      (str, SignatureDef): signature key, SignatureDef
    """
    # Base implementation carries no signature information.
    return None, None
def should_base64_decode(framework, model, signature_name):
  """Determines if base64 decoding is required.

  Returns False if framework is not TF.
  Returns True if framework is TF and is a user model.
  Returns True if framework is TF and model contains a str input.
  Returns False if framework is TF and model does not contain str input.

  Args:
    framework: ML framework of prediction app
    model: model object
    signature_name: str of name of signature

  Returns:
    bool
  """
  # Non-BaseModel (user-provided) models expose no signature information, so
  # we conservatively decode; otherwise decode only when the signature
  # declares a string input.
  return (framework == TENSORFLOW_FRAMEWORK_NAME and
          (not isinstance(model, BaseModel) or
           does_signature_contain_str(model.get_signature(signature_name)[1])))
def decode_base64(data):
if isinstance(data, list):
return [decode_base64(val) for val in data]
elif isinstance(data, dict):
if six.viewkeys(data) == {"b64"}:
return base64.b64decode(data["b64"])
else:
return {k: decode_base64(v) for k, v in six.iteritems(data)}
else:
return data
def does_signature_contain_str(signature=None):
  """Return true if input signature contains a string dtype.

  This is used to determine if we should proceed with base64 decoding.

  Args:
    signature: SignatureDef protocol buffer

  Returns:
    bool
  """
  # if we did not receive a signature we assume the model could require
  # a string in its input
  if signature is None:
    return True
  # Compare each input's dtype enum against the TF string dtype enum.
  return any(v.dtype == dtypes.string.as_datatype_enum
             for v in signature.inputs.values())
def copy_model_to_local(gcs_path, dest_path):
  """Copies the contents of a GCS directory into a local directory.

  The entries under gcs_path land directly inside dest_path (subdirectories
  are preserved). For example, copy_model_to_local("dir1", "/tmp") places
  dir1/file1 at /tmp/file1 and dir1/dir2/file3 at /tmp/dir2/file3.

  Args:
    gcs_path: Source GCS path that we're copying from.
    dest_path: Destination local path that we're copying to.

  Raises:
    Exception: If gsutil is not found.
  """
  start_time = time.time()
  logging.debug("Starting to copy files from %s to %s", gcs_path, dest_path)
  if not os.path.exists(dest_path):
    os.makedirs(dest_path)
  source_pattern = os.path.join(gcs_path, "*")
  # Removed parallel downloads ("-m") because it was not working well in
  # gVisor (b/37269226).
  command = ["gsutil", "cp", "-R", source_pattern, dest_path]
  try:
    subprocess.check_call(command, stdin=subprocess.PIPE)
  except subprocess.CalledProcessError:
    logging.exception("Could not copy model using gsutil.")
    raise
  elapsed = time.time() - start_time
  logging.debug("Files copied from %s to %s: took %f seconds", source_pattern,
                dest_path, elapsed)
def load_joblib_or_pickle_model(model_path):
  """Loads either a .joblib or .pkl file from GCS or from local.

  Loads one of DEFAULT_MODEL_FILE_NAME_JOBLIB or DEFAULT_MODEL_FILE_NAME_PICKLE
  files if they exist. This is used for both sklearn and xgboost.

  Arguments:
    model_path: The path to the directory that contains the model file. This
      path can be either a local path or a GCS path.

  Raises:
    PredictionError: If there is a problem while loading the file.

  Returns:
    A loaded scikit-learn or xgboost predictor object or None if neither
    DEFAULT_MODEL_FILE_NAME_JOBLIB nor DEFAULT_MODEL_FILE_NAME_PICKLE files are
    found.
  """
  if model_path.startswith("gs://"):
    # Materialize remote model files locally before loading.
    copy_model_to_local(model_path, LOCAL_MODEL_PATH)
    model_path = LOCAL_MODEL_PATH
  try:
    model_file_name_joblib = os.path.join(model_path,
                                          DEFAULT_MODEL_FILE_NAME_JOBLIB)
    model_file_name_pickle = os.path.join(model_path,
                                          DEFAULT_MODEL_FILE_NAME_PICKLE)
    if os.path.exists(model_file_name_joblib):
      model_file_name = model_file_name_joblib
      try:
        # Load joblib only when needed. If we put this at the top, we need to
        # add a dependency to sklearn anywhere that prediction_lib is called.
        # NOTE(review): the standalone joblib import below presumably covers
        # sklearn versions that no longer bundle joblib — confirm.
        from sklearn.externals import joblib  # pylint: disable=g-import-not-at-top
      except Exception as e:  # pylint: disable=broad-except
        try:
          # Load joblib only when needed. If we put this at the top, we need to
          # add a dependency to joblib anywhere that prediction_lib is called.
          import joblib  # pylint: disable=g-import-not-at-top
        except Exception as e:  # pylint: disable=broad-except
          error_msg = "Could not import joblib module."
          logging.exception(error_msg)
          raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL, error_msg)
      try:
        logging.info("Loading model %s using joblib.", model_file_name)
        return joblib.load(model_file_name)
      except KeyError:
        # NOTE(review): KeyError is apparently how joblib fails when the
        # .joblib file actually holds an xgboost binary model; retry with the
        # native xgboost loader as the log message below states — confirm.
        logging.info(
            ("Loading model %s using joblib failed. Loading model using "
             "xgboost.Booster instead."), model_file_name)
        # Load xgboost only when needed. If we put this at the top, we need to
        # add a dependency to xgboost anywhere that prediction_lib is called.
        import xgboost  # pylint: disable=g-import-not-at-top
        booster = xgboost.Booster()
        booster.load_model(model_file_name)
        return booster
    elif os.path.exists(model_file_name_pickle):
      model_file_name = model_file_name_pickle
      logging.info("Loading model %s using pickle.", model_file_name)
      with open(model_file_name, "rb") as f:
        return pickle.loads(f.read())
    return None
  except Exception as e:  # pylint: disable=broad-except
    raw_error_msg = str(e)
    # An "unsupported pickle protocol" error almost always means the model
    # was exported under a different Python major version.
    if "unsupported pickle protocol" in raw_error_msg:
      error_msg = (
          "Could not load the model: {}. {}. Please make sure the model was "
          "exported using python {}. Otherwise, please specify the correct "
          "'python_version' parameter when deploying the model.").format(
              model_file_name, raw_error_msg, sys.version_info[0])
    else:
      error_msg = "Could not load the model: {}. {}.".format(
          model_file_name, raw_error_msg)
    logging.exception(error_msg)
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL, error_msg)
def detect_sk_xgb_framework_from_obj(model_obj):
  """Distinguishes scikit-learn from xgboost using a loaded model object.

  Args:
    model_obj: A loaded model object.

  Returns:
    Either the scikit-learn framework name or the xgboost framework name.

  Raises:
    PredictionError: If the object's defining module belongs to neither
      library.
  """
  # The module that defines the object's type identifies the library.
  module_name = type(model_obj).__module__
  if "sklearn" in module_name:
    return SCIKIT_LEARN_FRAMEWORK_NAME
  if "xgboost" in module_name:
    return XGBOOST_FRAMEWORK_NAME
  error_msg = (
      "Invalid model type detected: {}.{}. "
      "Please make sure the model file is an exported sklearn model, "
      "xgboost model or pipeline.").format(
          module_name,
          type(model_obj).__name__)
  logging.critical(error_msg)
  raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL, error_msg)
def _count_num_files_in_path(model_path, specified_file_names):
"""Count how many specified files exist in model_path.
Args:
model_path: The local path to the directory that contains the model file.
specified_file_names: The file names to be checked
Returns:
An integer indicating how many specified_file_names are found in model_path.
"""
num_matches = 0
for file_name in specified_file_names:
if os.path.exists(os.path.join(model_path, file_name)):
num_matches += 1
return num_matches
def detect_framework(model_path):
  """Detect framework from model_path by analyzing file extensions.

  Args:
    model_path: The local path to the directory that contains the model file.

  Raises:
    PredictionError: If framework can not be identified from model path.

  Returns:
    A string representing the identified framework or None (custom code is
    assumed in this situation).
  """
  num_tensorflow_models = _count_num_files_in_path(
      model_path, TENSORFLOW_SPECIFIC_MODEL_FILE_NAMES)
  num_xgboost_models = _count_num_files_in_path(
      model_path, XGBOOST_SPECIFIC_MODEL_FILE_NAMES)
  num_sklearn_models = _count_num_files_in_path(model_path,
                                                SCIKIT_LEARN_MODEL_FILE_NAMES)
  num_matches = num_tensorflow_models + num_xgboost_models + num_sklearn_models
  # At most one framework-specific model file may be present; more than one
  # is ambiguous and treated as an error.
  if num_matches > 1:
    error_msg = "Multiple model files are found in the model_path: {}".format(
        model_path)
    logging.critical(error_msg)
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL, error_msg)
  if num_tensorflow_models == 1:
    return TENSORFLOW_FRAMEWORK_NAME
  elif num_xgboost_models == 1:
    return XGBOOST_FRAMEWORK_NAME
  elif num_sklearn_models == 1:
    # joblib/pickle file names are shared between sklearn and xgboost, so the
    # loaded model object itself must be inspected to tell them apart.
    model_obj = load_joblib_or_pickle_model(model_path)
    return detect_sk_xgb_framework_from_obj(model_obj)
  else:
    logging.warning(("Model files are not found in the model_path."
                     "Assumed to be custom code."))
    return None
def get_field_in_version_json(field_name):
  """Gets the value of field_name in the version being created, if it exists.

  Reads the "create_version_request" environment variable, which (when set)
  carries the JSON-encoded create version request, and looks up field_name
  inside its "version" object.

  Args:
    field_name: Name of the key used for retrieving the corresponding value
      from the version json object.

  Returns:
    The value of the given field in the version object of the user-provided
    create version request if it exists. Otherwise None is returned.
  """
  raw_request = os.environ.get("create_version_request")
  if not raw_request:
    return None
  request = json.loads(raw_request)
  if not request or not isinstance(request, dict):
    return None
  version = request.get("version")
  if not version or not isinstance(version, dict):
    return None
  value = version.get(field_name)
  logging.info("Found value: %s, for field: %s from create_version_request",
               value, field_name)
  return value
def _pop_required_field(json_obj, field_name, invalid_msg_template,
                        missing_msg_template):
  """Pops and returns json_obj[field_name] after validating the payload.

  Shared implementation for parse_predictions, parse_outputs and
  parse_instances, which previously duplicated this logic three times.

  Args:
    json_obj: Parsed JSON payload; must be a mapping.
    field_name: Required key to pop from the mapping.
    invalid_msg_template: Error format string used when json_obj is not a
      mapping; formatted with repr(json_obj).
    missing_msg_template: Error format string used when field_name is absent;
      formatted with field_name and repr(json_obj).

  Returns:
    The value removed from json_obj under field_name.

  Raises:
    ValueError: If json_obj is not a mapping or field_name is missing.
  """
  if not isinstance(json_obj, collections_lib.Mapping):
    raise ValueError(invalid_msg_template.format(repr(json_obj)))
  if field_name not in json_obj:
    raise ValueError(
        missing_msg_template.format(field_name, repr(json_obj)))
  return json_obj.pop(field_name)


def parse_predictions(response_json):
  """Parses the predictions from the json response from prediction server.

  Args:
    response_json: The parsed JSON response to validate.

  Returns:
    Predictions from the response json.

  Raises:
    ValueError: If response_json is malformed.
  """
  return _pop_required_field(
      response_json, PREDICTIONS_KEY,
      "Invalid response received from prediction server: {}",
      "Required field '{}' missing in prediction server response: {}")


def parse_outputs(response_json):
  """Parses the outputs from the json response from prediction server.

  Args:
    response_json: The parsed JSON response to validate.

  Returns:
    Outputs from the response json.

  Raises:
    ValueError: If response_json is malformed.
  """
  return _pop_required_field(
      response_json, OUTPUTS_KEY,
      "Invalid response received from prediction server: {}",
      "Required field '{}' missing in prediction server response: {}")


def parse_instances(request_json):
  """Parses instances from the json request sent to prediction server.

  Args:
    request_json: The parsed JSON request to validate.

  Returns:
    Instances from the request json.

  Raises:
    ValueError: If request_json is malformed.
  """
  return _pop_required_field(
      request_json, INSTANCES_KEY,
      "Invalid request sent to prediction server: {}",
      "Required field '{}' missing in prediction server request: {}")

View File

@@ -0,0 +1,22 @@
"""A py_binary running that's equivalent to running "python setup.py"."""
import sys
import setuptools
del setuptools # Ensure setuptools is available.
def main():
  """Executes the setup.py file named as the first command-line argument."""
  args = sys.argv
  if len(args) <= 1:
    raise RuntimeError("Must specify setup.py file as the first argument.")
  setup_path = args[1]
  with open(setup_path, "r") as setup_file:
    setup_content = setup_file.read()
  # Drop this runner from argv so the executed script sees the setup.py file
  # itself as argv[0], exactly as if "python setup.py" had been run.
  sys.argv = args[1:]
  exec(setup_content)  # pylint: disable=exec-used


if __name__ == "__main__":
  main()

View File

@@ -0,0 +1,19 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Google Cloud Machine Learning SDK
"""Utils used internally in the Cloud ML."""
from . import _decoders # pylint: disable=relative-import

View File

@@ -0,0 +1,195 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataflow-related utilities.
"""
import csv
import json
import logging
class DecodeError(Exception):
  """Base error raised when a record cannot be decoded."""
class PassthroughDecoder(object):
  """A no-op decoder that returns its input unchanged."""

  def decode(self, x):
    """Returns x as-is."""
    return x
class JsonDecoder(object):
  """Decodes records that are serialized as JSON strings."""

  def decode(self, x):
    """Parses x as JSON and returns the resulting object."""
    parsed = json.loads(x)
    return parsed
class CsvDecoder(object):
  """A decoder for CSV formatted data."""

  # TODO(user) Revisit using cStringIO for design compatibility with
  # coders.CsvCoder.
  class _LineGenerator(object):
    """A csv line generator that allows feeding lines to a csv.DictReader."""

    def __init__(self):
      self._lines = []

    def push_line(self, line):
      # This API currently supports only one line at a time.
      assert not self._lines
      self._lines.append(line)

    def __iter__(self):
      return self

    def __next__(self):
      # BUG FIX: this method was previously named `next` (Python 2 iterator
      # protocol only); Python 3 csv readers look up `__next__` and would
      # fail. The `next = __next__` alias below keeps Python 2 working.
      # This API currently supports only one line at a time.
      # If this ever supports more than one row be aware that DictReader might
      # attempt to read more than one record if one of the records is empty line
      line_length = len(self._lines)
      if line_length == 0:
        raise DecodeError(
            'Columns do not match specified csv headers: empty line was found')
      assert line_length == 1, 'Unexpected number of lines %s' % line_length
      # This doesn't maintain insertion order to the list, which is fine
      # because the list has only 1 element. If there were more and we wanted
      # to maintain order and timecomplexity we would switch to deque.popleft.
      return self._lines.pop()

    # Python 2 iterator protocol alias.
    next = __next__

  class _ReaderWrapper(object):
    """A wrapper for csv.reader / csv.DictReader to make it picklable."""

    def __init__(self, line_generator, column_names, delimiter, decode_to_dict,
                 skip_initial_space):
      self._state = (line_generator, column_names, delimiter, decode_to_dict,
                     skip_initial_space)
      self._line_generator = line_generator
      if decode_to_dict:
        self._reader = csv.DictReader(
            line_generator, column_names, delimiter=str(delimiter),
            skipinitialspace=skip_initial_space)
      else:
        self._reader = csv.reader(line_generator, delimiter=str(delimiter),
                                  skipinitialspace=skip_initial_space)

    def read_record(self, x):
      self._line_generator.push_line(x)
      # BUG FIX: was `self._reader.next()`, which only exists on Python 2
      # readers; the builtin next() works on both Python 2 and 3.
      return next(self._reader)

    def __getstate__(self):
      return self._state

    def __setstate__(self, state):
      self.__init__(*state)

  def __init__(
      self, column_names, numeric_column_names, delimiter, decode_to_dict,
      fail_on_error, skip_initial_space):
    """Initializer.

    Args:
      column_names: Tuple of strings. Order must match the order in the file.
      numeric_column_names: Tuple of strings. Contains column names that are
        numeric. Every name in numeric_column_names must also be in
        column_names.
      delimiter: String used to separate fields.
      decode_to_dict: Boolean indicating whether the decoder should generate a
        dictionary instead of a raw sequence.
      fail_on_error: Whether to fail if a corrupt row is found.
      skip_initial_space: When True, whitespace immediately following the
        delimiter is ignored.
    """
    self._column_names = column_names
    self._numeric_column_names = set(numeric_column_names)
    self._reader = self._ReaderWrapper(
        self._LineGenerator(), column_names, delimiter, decode_to_dict,
        skip_initial_space)
    self._decode_to_dict = decode_to_dict
    self._fail_on_error = fail_on_error

  def _handle_corrupt_row(self, message):
    """Handle corrupt rows.

    Depending on whether the decoder is configured to fail on error it will
    raise a DecodeError or return None.

    Args:
      message: String, the error message to raise.

    Returns:
      None, when the decoder is not configured to fail on error.

    Raises:
      DecodeError: when the decoder is configured to fail on error.
    """
    if self._fail_on_error:
      raise DecodeError(message)
    else:
      # TODO(user) Don't log every time but only every N invalid lines.
      logging.warning('Discarding invalid row: %s', message)
      return None

  def _get_value(self, column_name, value):
    # TODO(user) remove this logic from the decoders and let it be
    # part of preprocessing. CSV is a schema-less container; we shouldn't be
    # performing these conversions here.
    # Empty/whitespace-only fields become None; numeric columns become float.
    if not value or not value.strip():
      return None
    if column_name in self._numeric_column_names:
      return float(value)
    return value

  # Please run //third_party/py/google/cloud/ml:benchmark_coders_test
  # if you make any changes on these methods.
  def decode(self, record):
    """Decodes the given string.

    Args:
      record: String to be decoded.

    Returns:
      Serialized object corresponding to decoded string. Or None if there's an
      error and the decoder is configured not to fail on error.

    Raises:
      DecodeError: If columns do not match specified csv headers.
      ValueError: If some numeric column has non-numeric data.
    """
    try:
      record = self._reader.read_record(record)
    except Exception as e:  # pylint: disable=broad-except
      return self._handle_corrupt_row('%s: %s' % (e, record))
    # Check record length mismatches.
    if len(record) != len(self._column_names):
      return self._handle_corrupt_row(
          'Columns do not match specified csv headers: %s -> %s' % (
              self._column_names, record))
    if self._decode_to_dict:
      # DictReader fills missing colums with None. Thus, if the last value
      # as defined by the schema is None, there was at least one "missing"
      # column.
      if record[self._column_names[-1]] is None:
        return self._handle_corrupt_row(
            'Columns do not match specified csv headers: %s -> %s' % (
                self._column_names, record))
      # BUG FIX: `record.iteritems()` is Python 2 only; `items()` works on
      # both Python 2 and 3 (values may be reassigned during iteration since
      # no keys are added or removed).
      for name, value in record.items():
        record[name] = self._get_value(name, value)
    else:
      for index, name in enumerate(self._column_names):
        value = record[index]
        record[index] = self._get_value(name, value)
    return record

View File

@@ -0,0 +1,55 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Exceptions used when sending HTTP requests.
"""
import json
class _RequestException(Exception):
  """Exception returned by an HTTP request.

  Attributes:
    status: Integer HTTP status code of the failed request.
    content: Raw response body (usually a JSON string).
    message: Human-readable error message. If the body is JSON containing
      an ['error']['message'] entry, that entry is used; otherwise the raw
      content is used verbatim.
  """

  def __init__(self, status, content):
    super(_RequestException, self).__init__()
    self.status = status
    self.content = content
    self.message = content
    # Try to extract a message from the body. The body may not be a string
    # (TypeError), may not be valid JSON (ValueError), or may be JSON that
    # lacks the expected structure (KeyError); in every case fall back to
    # the raw content already stored above.
    try:
      self.message = json.loads(content)['error']['message']
    except (ValueError, KeyError, TypeError):
      pass

  def __str__(self):
    return self.message

  @property
  def error_code(self):
    """Returns the error code if one is present and None otherwise."""
    try:
      parsed_content = json.loads(self.content)
    except ValueError:
      # Response isn't json.
      # TODO(user): What if the response is html? We appear to get HTML
      # responses if we hit a path that the server doesn't recognize.
      # For example if you do project/operations/project/operations you
      # get an HTML error with status code 404.
      return None
    return parsed_content.get('error', {}).get('code', None)

View File

@@ -0,0 +1,160 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Some file reading utilities.
"""
import glob
import logging
import os
import shutil
import subprocess
import time
from apache_beam.io import gcsio
# TODO(user): Remove use of gsutil
def create_directory(path):
  """Creates a local directory.

  Does nothing if a Google Cloud Storage path is passed.

  Args:
    path: the path to create.

  Raises:
    ValueError: if path is a file or os.makedir fails.
  """
  if path.startswith('gs://'):
    return
  if os.path.isdir(path):
    return
  if os.path.isfile(path):
    raise ValueError('Unable to create location. "%s" exists and is a file.' %
                     path)
  try:
    os.makedirs(path)
  except OSError:
    # os.makedirs fails with OSError (permissions, concurrent creation, bad
    # path component, ...). Narrowed from a bare `except:` which also
    # swallowed KeyboardInterrupt/SystemExit; callers expect ValueError.
    raise ValueError('Unable to create location. "%s"' % path)
def open_local_or_gcs(path, mode):
  """Opens the given path, which may be local or on Google Cloud Storage."""
  if not path.startswith('gs://'):
    return open(path, mode)
  try:
    return gcsio.GcsIO().open(path, mode)
  except Exception as e:  # pylint: disable=broad-except
    # Currently we retry exactly once, to work around flaky gcs calls.
    logging.error('Retrying after exception reading gcs file: %s', e)
    time.sleep(10)
    return gcsio.GcsIO().open(path, mode)
def file_exists(path):
  """Returns whether the file exists."""
  is_gcs_path = path.startswith('gs://')
  return gcsio.GcsIO().exists(path) if is_gcs_path else os.path.exists(path)
def copy_file(src, dest):
  """Copy a file from src to dest.

  Supports local and Google Cloud Storage.

  Args:
    src: source path.
    dest: destination path.
  """
  with open_local_or_gcs(src, 'r') as src_handle, \
      open_local_or_gcs(dest, 'w') as dest_handle:
    shutil.copyfileobj(src_handle, dest_handle)
def write_file(path, data):
  """Writes data to a file.

  Supports local and Google Cloud Storage.

  Args:
    path: output file path.
    data: data to write to file.
  """
  out_handle = open_local_or_gcs(path, 'w')
  try:
    out_handle.write(data)  # pylint: disable=no-member
  finally:
    out_handle.close()
def load_file(path):
  """Reads the full contents of a local or Cloud Storage file.

  Args:
    path: path to read; 'gs://' paths are fetched from Cloud Storage.

  Returns:
    The file contents.

  Raises:
    ValueError: if no content could be loaded from the path.
  """
  reader = _read_cloud_file if path.startswith('gs://') else _read_local_file
  content = reader(path)
  if content is None:
    raise ValueError('File cannot be loaded from %s' % path)
  return content
def glob_files(path):
  """Expands a glob pattern, locally or on Google Cloud Storage."""
  if not path.startswith('gs://'):
    return glob.glob(path)
  return gcsio.GcsIO().glob(path)
def _read_local_file(local_path):
  """Returns the entire contents of a file on the local filesystem."""
  handle = open(local_path, 'r')
  try:
    return handle.read()
  finally:
    handle.close()
def _read_cloud_file(storage_path):
  """Returns the entire contents of a file on Google Cloud Storage."""
  src = open_local_or_gcs(storage_path, 'r')
  try:
    return src.read()
  finally:
    src.close()
def read_file_stream(file_list):
  """Yields lines from each path in file_list (or from a single path).

  Args:
    file_list: a list of paths, or a single path string. Paths starting
      with 'gs://' are streamed from Google Cloud Storage.

  Yields:
    Lines of the file(s), in order.
  """
  paths = [file_list] if isinstance(file_list, basestring) else file_list
  for path in paths:
    reader = (_read_cloud_file_stream if path.startswith('gs://')
              else _read_local_file_stream)
    for line in reader(path):
      yield line
def _read_local_file_stream(path):
  """Yields the lines of a local file one at a time."""
  file_in = open(path)
  try:
    for line in file_in:
      yield line
  finally:
    file_in.close()
def _read_cloud_file_stream(path):
  """Streams a GCS file's lines through a `gsutil cp <path> -` subprocess.

  Yields:
    Lines of the remote file.

  Raises:
    IOError: if gsutil exits with a nonzero status.
  """
  gsutil_proc = subprocess.Popen(
      ['gsutil', 'cp', path, '-'],
      stdout=subprocess.PIPE,
      stderr=subprocess.PIPE)
  with gsutil_proc.stdout as file_in:
    for line in file_in:
      yield line
  # Only checked once the stream is fully consumed; surface gsutil failures.
  if gsutil_proc.wait() != 0:
    raise IOError('Unable to read data from Google Cloud Storage: "%s"' % path)

View File

@@ -0,0 +1,183 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Retry decorators for calls raising exceptions.
This module is used mostly to decorate all integration points where the code
makes calls to remote services. Searching through the code base for @retry
should find all such places. For this reason even places where retry is not
needed right now use a @retry.no_retries decorator.
"""
import logging
import random
import sys
import time
import traceback
from google.cloud.ml.util import _exceptions
from six import reraise
class FuzzedExponentialIntervals(object):
  """Iterable for intervals that are exponentially spaced, with fuzzing.

  On iteration, yields retry interval lengths, in seconds. Every iteration over
  this iterable will yield differently fuzzed interval lengths, as long as fuzz
  is nonzero.

  Args:
    initial_delay_secs: The delay before the first retry, in seconds.
    num_retries: The total number of times to retry.
    factor: The exponential factor to use on subsequent retries.
      Default is 2 (doubling).
    fuzz: A value between 0 and 1, indicating the fraction of fuzz. For a
      given delay d, the fuzzed delay is randomly chosen between
      [(1 - fuzz) * d, d].
    max_delay_secs: Maximum delay (in seconds). After this limit is reached,
      further tries use max_delay_secs instead of exponentially increasing
      the time. Defaults to 30 seconds.
  """

  def __init__(self,
               initial_delay_secs,
               num_retries,
               factor=2,
               fuzz=0.5,
               max_delay_secs=30):
    self._initial_delay_secs = initial_delay_secs
    self._num_retries = num_retries
    self._factor = factor
    if not 0 <= fuzz <= 1:
      raise ValueError('Fuzz parameter expected to be in [0, 1] range.')
    self._fuzz = fuzz
    self._max_delay_secs = max_delay_secs

  def __iter__(self):
    # The cap applies to the very first delay as well as all later ones.
    current_delay_secs = min(self._max_delay_secs, self._initial_delay_secs)
    for _ in range(self._num_retries):
      # fuzz_multiplier is uniform in [1 - fuzz, 1], so the yielded delay
      # lies in [(1 - fuzz) * d, d] as documented above.
      fuzz_multiplier = 1 - self._fuzz + random.random() * self._fuzz
      yield current_delay_secs * fuzz_multiplier
      current_delay_secs = min(self._max_delay_secs,
                               current_delay_secs * self._factor)
def retry_on_server_errors_filter(exception):
  """Filter allowing retries on server errors and non-HttpErrors."""
  if not isinstance(exception, _exceptions._RequestException):  # pylint: disable=protected-access
    # We may get here for non HttpErrors such as socket timeouts, SSL
    # exceptions, etc.
    return True
  # Only 5xx responses from the server are considered transient.
  return exception.status >= 500
class Clock(object):
  """A simple clock implementing sleep().

  Exists so tests can substitute a fake clock and avoid real delays.
  """

  def sleep(self, value):
    """Blocks for `value` seconds via time.sleep."""
    time.sleep(value)
def no_retries(fun):
  """A retry decorator for places where we do not want retries."""
  # A filter that rejects every exception means the first failure propagates.
  decorator = with_exponential_backoff(retry_filter=lambda _: False, clock=None)
  return decorator(fun)
def with_exponential_backoff(num_retries=10,
                             initial_delay_secs=1,
                             logger=logging.warning,
                             retry_filter=retry_on_server_errors_filter,
                             clock=Clock(),
                             fuzz=True):
  """Decorator with arguments that control the retry logic.

  Args:
    num_retries: The total number of times to retry.
    initial_delay_secs: The delay before the first retry, in seconds.
    logger: A callable used to report an exception. Must have the same
      signature as functions in the standard logging module. The default is
      logging.warning.
    retry_filter: A callable getting the exception raised and returning True
      if the retry should happen. For instance we do not want to retry on
      404 Http errors most of the time. The default value will return true
      for server errors (HTTP status code >= 500) and non Http errors.
    clock: A clock object implementing a sleep method. The default clock will
      use time.sleep().
    fuzz: True if the delay should be fuzzed (default). During testing False
      can be used so that the delays are not randomized.

  Returns:
    As per Python decorators with arguments pattern returns a decorator
    for the function which in turn will return the wrapped (decorated) function.

    The decorator is intended to be used on callables that make HTTP or RPC
    requests that can temporarily timeout or have transient errors. For
    instance the make_http_request() call below will be retried num_retries
    times with exponential backoff and fuzzing of the delay interval
    (default settings).

    from cloudml.util import retry
    # ...
    @retry.with_exponential_backoff()
    make_http_request(args)
  """

  def real_decorator(fun):
    """The real decorator whose purpose is to return the wrapped function."""

    def wrapper(*args, **kwargs):
      # Build a fresh interval iterator on every invocation so each call to
      # the wrapped function gets the full retry budget. Previously a single
      # iterator was shared by all invocations, so after num_retries failures
      # across ALL calls combined, every later failure raised immediately.
      retry_intervals = iter(
          FuzzedExponentialIntervals(
              initial_delay_secs, num_retries, fuzz=0.5 if fuzz else 0))
      while True:
        try:
          return fun(*args, **kwargs)
        except Exception as exn:  # pylint: disable=broad-except
          if not retry_filter(exn):
            raise
          # Get the traceback object for the current exception. The
          # sys.exc_info() function returns a tuple with three elements:
          # exception type, exception value, and exception traceback.
          exn_traceback = sys.exc_info()[2]
          try:
            try:
              sleep_interval = next(retry_intervals)
            except StopIteration:
              # Re-raise the original exception since we finished the retries.
              reraise(type(exn), exn, exn_traceback)
            logger(
                'Retry with exponential backoff: waiting for %s seconds before '
                'retrying %s because we caught exception: %s '
                'Traceback for above exception (most recent call last):\n%s',
                sleep_interval,
                getattr(fun, '__name__', str(fun)),
                ''.join(traceback.format_exception_only(exn.__class__, exn)),
                ''.join(traceback.format_tb(exn_traceback)))
            clock.sleep(sleep_interval)
          finally:
            # Traceback objects in locals can cause reference cycles that will
            # prevent garbage collection. Clear it now since we do not need
            # it anymore.
            if sys.version_info < (3, 0):  # only for py 2
              sys.exc_clear()
            exn_traceback = None

    return wrapper

  return real_decorator

View File

@@ -0,0 +1,45 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud ML Engine Prediction version constants.
"""
# Google Cloud Machine Learning Prediction Engine Version info.
__version__ = '0.1-alpha'

# Base pip requirements shared by every install flavor below. Version ranges
# are deliberately conservative; pins annotated with a b/ bug number are
# workarounds for specific breakages, not policy.
required_install_packages = [
    'oauth2client >= 2.2.0',
    'six >= 1.10.0, < 2.0',
    'bs4 >= 0.0.1, < 1.0',
    'numpy >= 1.10.4',  # Don't pin numpy, as it requires a recompile.
    'crcmod >= 1.7, < 2.0',
    'nltk >= 3.2.1, <= 3.4',
    'pyyaml >= 3.11, < 7.0',
    'protobuf >= 3.1.0, < 4.0',
    # isort is avro dependency which picks the latest.
    # We do not want use latest because of b/160639883.
    'isort < 5.0',
    # Python 3.7 seems incompatible with enum34. See b/148202403.
    'enum34 >= 1.1; python_version <= "3.5"',
]

# Additional requirements when (Beam-based) batch prediction is enabled.
required_install_packages_with_batch_prediction = required_install_packages + [
    # Remove < 2.4.0 after b/77730826 is fixed.
    'apache-beam[gcp] >= 2.0.0, < 2.4.0',
    'google-cloud-logging >= 0.23.0, < 1.0',
]

# Additional requirements for installs performed with --no-deps, where
# transitive dependencies must be listed explicitly.
required_install_packages_no_deps = required_install_packages + [
    'google-cloud-logging >= 0.23.0, <=1.15.0',
    'google-api-python-client <= 1.9.0',
]