# -*- coding: utf-8 -*- # # Copyright 2019 Google LLC. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ai-platform local predict command.""" from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals from googlecloudsdk.calliope import base from googlecloudsdk.command_lib.ml_engine import flags from googlecloudsdk.command_lib.ml_engine import local_utils from googlecloudsdk.command_lib.ml_engine import predict_utilities from googlecloudsdk.core import log def _AddLocalPredictArgs(parser): """Add arguments for `gcloud ai-platform local predict` command.""" parser.add_argument('--model-dir', required=True, help='Path to the model.') flags.FRAMEWORK_MAPPER.choice_arg.AddToParser(parser) group = parser.add_mutually_exclusive_group(required=True) group.add_argument( '--json-request', help="""\ Path to a local file containing the body of JSON request. An example of a JSON request: { "instances": [ {"x": [1, 2], "y": [3, 4]}, {"x": [-1, -2], "y": [-3, -4]} ] } This flag accepts "-" for stdin. """) group.add_argument( '--json-instances', help="""\ Path to a local file from which instances are read. Instances are in JSON format; newline delimited. An example of the JSON instances file: {"images": [0.0, ..., 0.1], "key": 3} {"images": [0.0, ..., 0.1], "key": 2} ... This flag accepts "-" for stdin. """) group.add_argument( '--text-instances', help="""\ Path to a local file from which instances are read. Instances are in UTF-8 encoded text format; newline delimited. An example of the text instances file: 107,4.9,2.5,4.5,1.7 100,5.7,2.8,4.1,1.3 ... This flag accepts "-" for stdin. """) flags.SIGNATURE_NAME.AddToParser(parser) class Predict(base.Command): """Run prediction locally.""" @staticmethod def Args(parser): _AddLocalPredictArgs(parser) def Run(self, args): framework = flags.FRAMEWORK_MAPPER.GetEnumForChoice(args.framework) framework_flag = framework.name.lower() if framework else 'tensorflow' if args.signature_name is None: log.status.Print('If the signature defined in the model is ' 'not serving_default then you must specify it via ' '--signature-name flag, otherwise the command may fail.') results = local_utils.RunPredict( args.model_dir, json_request=args.json_request, json_instances=args.json_instances, text_instances=args.text_instances, framework=framework_flag, signature_name=args.signature_name) if not args.IsSpecified('format'): # default format is based on the response. if isinstance(results, list): predictions = results else: predictions = results.get('predictions') args.format = predict_utilities.GetDefaultFormat(predictions) return results _DETAILED_HELP = { 'DESCRIPTION': """\ *{command}* performs prediction locally with the given instances. It requires the [TensorFlow SDK](https://www.tensorflow.org/install) be installed locally. The output format mirrors `gcloud ai-platform predict` (online prediction). You cannot use this command with custom prediction routines. """ } Predict.detailed_help = _DETAILED_HELP