372 lines
12 KiB
Python
372 lines
12 KiB
Python
# -*- coding: utf-8 -*- #
|
|
# Copyright 2018 Google LLC. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Common utilities for the gcloud export/import commands."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import unicode_literals
|
|
|
|
import copy
|
|
import json
|
|
import os
|
|
import re
|
|
import textwrap
|
|
|
|
from apitools.base.protorpclite import message_types
|
|
from apitools.base.protorpclite import messages
|
|
from apitools.base.py import encoding as api_encoding
|
|
from apitools.base.py import encoding_helper
|
|
from googlecloudsdk.api_lib.dataproc import exceptions
|
|
from googlecloudsdk.core import log
|
|
from googlecloudsdk.core import yaml
|
|
from googlecloudsdk.core import yaml_validator
|
|
from googlecloudsdk.core.util import encoding
|
|
|
|
|
|
def AddExportFlags(parser, schema_path=None):
    """Registers the optional `--destination` flag used by export commands.

    Args:
      parser: The argparse parser object.
      schema_path: The resource instance schema file path if there is one. When
        given, the flag help points the user at this path.
    """
    help_parts = [
        'Path to a YAML file where the configuration will be exported.\n'
        'Alternatively, you may omit this flag to write to standard output.'
    ]
    if schema_path is not None:
        help_parts.append(
            ' For a schema describing the export/import format, see:\n'
            '{}.\n'.format(schema_path))

    parser.add_argument(
        '--destination',
        # Not required: omitting the flag means "write to stdout".
        required=False,
        help=textwrap.dedent(''.join(help_parts)))
|
|
|
|
|
|
def AddImportFlags(parser, schema_path=None):
    """Add common import flags to the arg parser.

    Registers an optional `--source` flag. When a schema path is supplied the
    help text points at it, and if that path contains the `$CLOUDSDKROOT`
    placeholder a note explaining the placeholder is appended.

    Args:
      parser: The argparse parser object.
      schema_path: The resource instance schema file path if there is one.
    """
    help_text = (
        'Path to a YAML file containing configuration export data.\n'
        'Alternatively, you may omit this flag to read from standard input.')
    if schema_path is not None:
        # The leading space separates this sentence from the previous one;
        # without it the help renders as "standard input.For a schema ...".
        help_text += (
            ' For a schema describing the export/import format, see:\n'
            '{}.\n').format(schema_path)
        # Only meaningful when a schema path exists, hence nested in the guard.
        if '$CLOUDSDKROOT' in schema_path:
            help_text += (
                '\n\n'
                "Note: $CLOUDSDKROOT represents the Google Cloud CLI's "
                'installation directory.\n')

    parser.add_argument(
        '--source',
        help=textwrap.dedent(help_text),
        # Allow reading from stdin.
        required=False)
|
|
|
|
|
|
def GetSchemaPath(api_name, api_version='v1', message_name=None,
                  for_help=False):
    """Builds the installation path of a resource schema file.

    The path has the form

      $CLOUDSDKROOT/lib/googlecloudsdk/schemas/
        {api}/{api_version}/{message_name}.yaml

    Args:
      api_name: The api name.
      api_version: The API version string.
      message_name: The UpperCamelCase message name.
      for_help: Replaces the actual Cloud SDK installation root dir with
        $CLOUDSDKROOT.

    Returns:
      The schema file path. With for_help set, everything before the last
      `googlecloudsdk` path component is replaced by `$CLOUDSDKROOT/lib`; if
      no such component exists the real path is returned unchanged.
    """
    # Strip three trailing components from this file's path to reach the
    # directory that holds the 'schemas' tree.
    base_dir = encoding.Decode(__file__)
    for _ in range(3):
        base_dir = os.path.dirname(base_dir)

    schema_file = '{}.yaml'.format(message_name)
    path = os.path.join(base_dir, 'schemas', api_name, api_version, schema_file)
    if not for_help:
        return path

    marker = os.path.sep + 'googlecloudsdk' + os.path.sep
    marker_index = path.rfind(marker)
    if marker_index < 0:
        return path
    # +1 skips the separator that precedes 'googlecloudsdk'.
    return os.path.join('$CLOUDSDKROOT', 'lib', path[marker_index + 1:])
|
|
|
|
|
|
def ValidateYAML(parsed_yaml, schema_path):
    """Checks already-parsed YAML against a JSON schema.

    Args:
      parsed_yaml: YAML to validate
      schema_path: JSON schema file path.

    Raises:
      IOError: if schema not found in installed resources.
      files.Error: if schema file not found.
      ValidationError: if the template doesn't obey the schema.
      SchemaError: if the schema is invalid.
    """
    validator = yaml_validator.Validator(schema_path)
    validator.Validate(parsed_yaml)
|
|
|
|
|
|
def _ParseProperties(error_message):
|
|
"""Parses disallowed properties from an error message.
|
|
|
|
Args:
|
|
error_message: The error message to parse.
|
|
|
|
Returns:
|
|
A list of property names.
|
|
|
|
A sample error message might look like this:
|
|
|
|
Additional properties are not allowed ('id', 'createTime', 'updateTime',
|
|
'name' were unexpected)
|
|
|
|
"""
|
|
return list(
|
|
property.strip('\'') for property in re.findall("'[^']*'", error_message))
|
|
|
|
|
|
def _ClearFields(fields, path_deque, py_dict):
|
|
"""Clear the given fields in a dict at a given path.
|
|
|
|
Args:
|
|
fields: A list of fields to clear
|
|
path_deque: A deque containing path segments
|
|
py_dict: A nested dict from which to clear the fields
|
|
"""
|
|
tmp_dict = py_dict
|
|
for elem in path_deque:
|
|
tmp_dict = tmp_dict[elem]
|
|
for field in fields:
|
|
if field in tmp_dict:
|
|
del tmp_dict[field]
|
|
|
|
|
|
def _IsDisallowedPropertiesError(error):
|
|
"""Checks if an error is due to properties that were not in the schema.
|
|
|
|
Args:
|
|
error: A ValidationError
|
|
|
|
Returns:
|
|
Whether the error was due to disallowed properties
|
|
"""
|
|
prop_validator = 'additionalProperties'
|
|
prop_message = 'Additional properties are not allowed'
|
|
return error.validator == prop_validator and prop_message in error.message
|
|
|
|
|
|
def _FilterYAML(parsed_yaml, schema_path):
    """Strips fields that the schema does not allow from the yaml, in place.

    Args:
      parsed_yaml: yaml to filter
      schema_path: Path to schema.
    """
    validator = yaml_validator.Validator(schema_path)
    saw_other_errors = False
    for err in validator.Iterate(parsed_yaml):
        # Disallowed-property errors are the only kind expected on export and
        # the only kind acted on. Disallowed fields cannot be told apart from
        # unrecognized ones. Any other error (e.g. a new enum value on a known
        # field, or a missing required field) is only logged as a warning; the
        # exported data may then not be importable until the import command is
        # updated.
        if not _IsDisallowedPropertiesError(err):
            log.warning(err.message)
            saw_other_errors = True
            continue
        _ClearFields(_ParseProperties(err.message), err.path, parsed_yaml)

    if saw_other_errors:
        log.warning('The import command may need to be updated to handle '
                    'the export data.')
|
|
|
|
|
|
def Import(message_type, stream, schema_path=None):
    """Loads YAML from a stream and converts it to a message.

    Args:
      message_type: Type of message to load YAML into.
      stream: Input stream or buffer containing the YAML.
      schema_path: JSON schema file path. None for no YAML validation.

    Raises:
      ParseError: if yaml could not be parsed as the given message type.

    Returns:
      message_type object.
    """
    parsed_yaml = yaml.load(stream)

    # Validation runs outside the try block, so validator errors propagate
    # unchanged rather than being rewrapped as ParseError.
    if schema_path:
        ValidateYAML(parsed_yaml, schema_path)

    try:
        return api_encoding.PyValueToMessage(message_type, parsed_yaml)
    except Exception as e:
        raise exceptions.ParseError('Cannot parse YAML: [{0}]'.format(e))
|
|
|
|
|
|
# pylint: disable=protected-access
|
|
# TODO(b/177577343)
|
|
# This is a terrible hack to fix an apitools issue that makes all registered
|
|
# codecs global, which breaks our presubmits. This will be removed once
|
|
# declarative workflows deprecate our current import/export tooling.
|
|
class _ProtoJsonApiTools(encoding_helper._ProtoJsonApiTools):
    """JSON encoder used by apitools clients.

    Subclass of the apitools encoder that routes unknown-field handling, custom
    field names and field codecs through this module's helpers
    (_EncodeUnknownFields, _EncodeCustomFieldNames, _GetFieldCodecs). A single
    shared instance is available through Get().
    """
    # Lazily created shared instance; see Get().
    _INSTANCE = None

    @classmethod
    def Get(cls):
        """Returns the shared instance, creating it on first use."""
        if cls._INSTANCE is None:
            cls._INSTANCE = cls()
        return cls._INSTANCE

    def encode_message(self, message):
        """Encodes a message (or a FieldList of messages) as a JSON string.

        Args:
          message: The message or messages.FieldList to encode.

        Returns:
          The JSON text. For a single message without a custom message codec,
          the keys are sorted.
        """
        # A FieldList is rendered as a JSON array of individually encoded
        # items.
        if isinstance(message, messages.FieldList):
            return '[%s]' % (', '.join(self.encode_message(x) for x in message))

        # A codec registered for this exact message type takes over entirely.
        # pylint: disable=unidiomatic-typecheck
        if type(message) in encoding_helper._CUSTOM_MESSAGE_CODECS:
            return encoding_helper._CUSTOM_MESSAGE_CODECS[type(message)].encoder(
                message)

        message = _EncodeUnknownFields(message)
        # super() is anchored at the apitools class rather than at this class,
        # so method lookup starts after encoding_helper._ProtoJsonApiTools in
        # the MRO and that class's own encode_message is bypassed.
        result = super(encoding_helper._ProtoJsonApiTools,
                       self).encode_message(message)
        result = _EncodeCustomFieldNames(message, result)
        # Round-trip through json to emit keys in sorted order.
        return json.dumps(json.loads(result), sort_keys=True)

    def encode_field(self, field, value):
        """Encodes a single field value.

        Args:
          field: The field descriptor.
          value: The value to encode.

        Returns:
          The encoded value.
        """
        # Apply the applicable codecs in order; each sees the previous codec's
        # output, and a codec that reports `complete` ends encoding at once.
        for encoder in _GetFieldCodecs(field, 'encoder'):
            result = encoder(field, value)
            value = result.value
            if result.complete:
                return value
        if isinstance(field, messages.EnumField):
            if field.repeated:
                # Per element: use the custom JSON name if one is mapped,
                # otherwise fall back to the enum's own name.
                remapped_value = [
                    encoding_helper.GetCustomJsonEnumMapping(
                        field.type, python_name=e.name) or e.name for e in value
                ]
            else:
                remapped_value = encoding_helper.GetCustomJsonEnumMapping(
                    field.type, python_name=value.name)
            # A falsy result (no mapping, or an empty list) falls through to
            # the generic handling below.
            if remapped_value:
                return remapped_value
        # Nested messages other than DateTimeField are encoded with this
        # encoder and parsed back into Python values before delegation.
        if (isinstance(field, messages.MessageField) and
                not isinstance(field, message_types.DateTimeField)):
            value = json.loads(self.encode_message(value))
        # Same anchored super() as in encode_message: the apitools class's own
        # encode_field is bypassed.
        return super(encoding_helper._ProtoJsonApiTools,
                     self).encode_field(field, value)
|
|
|
|
|
|
def RegisterCustomFieldTypeCodecs(field_type_codecs):
    """Registers custom field codec for int64s.

    Args:
      field_type_codecs: Mapping from field type to codec. It is modified in
        place: the entry for messages.IntegerField is set to a codec whose
        encoder and decoder both return the value unchanged and mark the
        result complete.

    Returns:
      The same field_type_codecs mapping that was passed in.
    """

    def _PassThroughEncoder(unused_field, value):
        return api_encoding.CodecResult(value=value, complete=True)

    def _PassThroughDecoder(unused_field, value):
        # Don't need to do anything special, they're decoded just fine
        return api_encoding.CodecResult(value=value, complete=True)

    int_codec = encoding_helper._Codec(
        encoder=_PassThroughEncoder, decoder=_PassThroughDecoder)
    field_type_codecs[messages.IntegerField] = int_codec
    return field_type_codecs
|
|
|
|
|
|
def _GetFieldCodecs(field, attr):
    """Collects the codec callables that apply to a field.

    Works on deep copies of the apitools codec registries, so registering the
    integer codec here does not modify the registries themselves.

    Args:
      field: The field to look up codecs for.
      attr: Name of the codec attribute to fetch, e.g. 'encoder'.

    Returns:
      A list with the per-field codec callable first and the per-field-type
      callable second, omitting whichever is missing.
    """
    per_field_codecs = copy.deepcopy(encoding_helper._CUSTOM_FIELD_CODECS)
    per_type_codecs = RegisterCustomFieldTypeCodecs(
        copy.deepcopy(encoding_helper._FIELD_TYPE_CODECS))

    candidates = (
        per_field_codecs.get(field),
        per_type_codecs.get(type(field)),
    )
    found = []
    for codec in candidates:
        # getattr with a default also covers the "no codec registered" case,
        # where codec is None.
        fn = getattr(codec, attr, None)
        if fn is not None:
            found.append(fn)
    return found
|
|
|
|
|
|
def _EncodeUnknownFields(message):
    """Remap unknown fields in message out of message.source.

    Args:
      message: The message to process.

    Returns:
      The message itself when no unrecognized-field mapping is registered for
      its type. Otherwise a copy in which every key/value pair of the mapped
      field has been set as an unrecognized field and the mapped field itself
      is reset to an empty list.

    Raises:
      InvalidUserInputError: if the mapped field is not a MessageField.
    """
    pairs_attr = encoding_helper._UNRECOGNIZED_FIELD_MAPPINGS.get(
        type(message))
    if pairs_attr is None:
        return message

    # CopyProtoMessage uses _ProtoJsonApiTools, which uses this message. Use
    # the vanilla protojson-based copy function to avoid infinite recursion.
    remapped = encoding_helper._CopyProtoMessageVanillaProtoJson(message)

    pairs_field = message.field_by_name(pairs_attr)
    if not isinstance(pairs_field, messages.MessageField):
        raise exceptions.InvalidUserInputError('Invalid pairs field %s' %
                                               pairs_field)

    value_field = pairs_field.message_type.field_by_name('value')
    variant = value_field.variant
    encoder = _ProtoJsonApiTools.Get()
    for pair in getattr(message, pairs_attr):
        remapped.set_unrecognized_field(
            pair.key, encoder.encode_field(value_field, pair.value), variant)

    setattr(remapped, pairs_attr, [])
    return remapped
|
|
|
|
|
|
def _EncodeCustomFieldNames(message, encoded_value):
    """Renames top-level JSON keys according to the custom field mappings.

    Args:
      message: The message that was encoded; its type selects the mappings.
      encoded_value: JSON text of the encoded message.

    Returns:
      encoded_value unchanged when no mappings are registered for the message
      type; otherwise the re-serialized JSON text with each mapped Python field
      name that is present as a key replaced by its JSON name.
    """
    field_remappings = list(
        encoding_helper._JSON_FIELD_MAPPINGS.get(type(message), {}).items())
    if not field_remappings:
        return encoded_value

    decoded_value = json.loads(encoded_value)
    for python_name, json_name in field_remappings:
        # Test membership against the decoded dict, not the raw JSON text: a
        # substring hit in the text (inside a value or another key) would make
        # pop() raise KeyError, and escaped keys would never match.
        if python_name in decoded_value:
            decoded_value[json_name] = decoded_value.pop(python_name)
    return json.dumps(decoded_value)
|
|
|
|
|
|
def Export(message, stream=None, schema_path=None):
    """Serializes a message to YAML, optionally writing it to a stream.

    Args:
      message: Message to write.
      stream: Output stream, None for writing to a string and returning it.
      schema_path: JSON schema file path. If None then all message fields are
        written, otherwise only fields in the schema are written.

    Returns:
      Returns the return value of yaml.dump(). If stream is None then the
      return value is the YAML data as a string.
    """
    encoder = _ProtoJsonApiTools.Get()
    encoded_json = encoder.encode_message(message)
    full_json = encoding_helper._IncludeFields(encoded_json, message, None)
    message_dict = json.loads(full_json)

    # Filtering mutates message_dict in place.
    if schema_path:
        _FilterYAML(message_dict, schema_path)

    return yaml.dump(message_dict, stream=stream)
|