123 lines
4.4 KiB
Python
123 lines
4.4 KiB
Python
# -*- coding: utf-8 -*- #
|
|
# Copyright 2016 Google LLC. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Utilities for dealing with ML predict API."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import unicode_literals
|
|
|
|
import json
|
|
|
|
from googlecloudsdk.core import exceptions as core_exceptions
|
|
from googlecloudsdk.core.credentials import requests
|
|
|
|
from six.moves import http_client as httplib
|
|
|
|
|
|
class InstancesEncodeError(core_exceptions.Error):
  """Raised when the request instances cannot be serialized to JSON."""
|
|
|
|
|
|
class HttpRequestFailError(core_exceptions.Error):
  """Raised when the HTTP request fails or its response is not valid JSON."""
|
|
|
|
|
|
def _GetPrediction(url, body, headers):
  """Sends the request body to the given URL with an HTTP POST.

  Args:
    url: the URL the request is sent to.
    body: the payload passed as the request data.
    headers: the HTTP headers to send with the request.

  Returns:
    A (status_code, text) tuple taken from the HTTP response.
  """
  session = requests.GetSession()
  http_response = session.request('POST', url, data=body, headers=headers)
  return http_response.status_code, http_response.text
|
|
|
|
|
|
def Predict(model_or_version_ref, instances, signature_name=None):
  """Sends an online prediction request for the given instances.

  Args:
    model_or_version_ref: a Resource representing either a model or a version.
    instances: a list of JSON or UTF-8 encoded instances to perform
      prediction on.
    signature_name: name of input/output signature in the TF meta graph.
      Left out of the request when empty or None.

  Returns:
    The JSON-decoded body of the prediction response.

  Raises:
    InstancesEncodeError: if the instances cannot be serialized to JSON.
    HttpRequestFailError: if the HTTP status is not OK, or the response body
      cannot be parsed as JSON.
  """
  predict_url = model_or_version_ref.SelfLink() + ':predict'
  request_headers = {'Content-Type': 'application/json'}

  # Assemble the JSON payload for the predict request.
  payload = {'instances': instances}
  if signature_name:
    payload['signature_name'] = signature_name

  try:
    request_body = json.dumps(payload, sort_keys=True)
  except (UnicodeDecodeError, TypeError):
    # Python 2 raises UnicodeDecodeError, Python 3 raises TypeError.
    raise InstancesEncodeError('Instances cannot be JSON encoded, probably '
                               'because the input is not utf-8 encoded.')

  # Workaround since gcloud cannot handle HttpBody properly, see b/31403673
  status, text = _GetPrediction(predict_url, request_body, request_headers)
  if int(status) != httplib.OK:
    raise HttpRequestFailError('HTTP request failed. Response: ' + text)

  try:
    predictions = json.loads(text)
  except ValueError:
    raise HttpRequestFailError('No JSON object could be decoded from the '
                               'HTTP response body: ' + text)
  else:
    return predictions
|
|
|
|
|
|
def Explain(model_or_version_ref, instances):
  """Sends an online explanation request for the given instances.

  Args:
    model_or_version_ref: a Resource representing either a model or a version.
    instances: a list of JSON or UTF-8 encoded instances to compute
      explanations for.

  Returns:
    The JSON-decoded body of the explanation response.

  Raises:
    InstancesEncodeError: if the instances cannot be serialized to JSON.
    HttpRequestFailError: if the HTTP status is not OK, or the response body
      cannot be parsed as JSON.
  """
  explain_url = model_or_version_ref.SelfLink() + ':explain'
  request_headers = {'Content-Type': 'application/json'}

  # Assemble the JSON payload for the explain request.
  payload = {'instances': instances}
  try:
    request_body = json.dumps(payload, sort_keys=True)
  except (UnicodeDecodeError, TypeError):
    # Python 2 raises UnicodeDecodeError, Python 3 raises TypeError.
    raise InstancesEncodeError('Instances cannot be JSON encoded, probably '
                               'because the input is not utf-8 encoded.')

  # Workaround since gcloud cannot handle HttpBody properly, see b/31403673
  status, text = _GetPrediction(explain_url, request_body, request_headers)
  if int(status) != httplib.OK:
    raise HttpRequestFailError('HTTP request failed. Response: ' + text)

  try:
    explanations = json.loads(text)
  except ValueError:
    raise HttpRequestFailError('No JSON object could be decoded from the '
                               'HTTP response body: ' + text)
  else:
    return explanations
|