# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Speech-to-text V2 client."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import unicode_literals
|
|
|
|
import contextlib
|
|
import os
|
|
|
|
from apitools.base.py import list_pager
|
|
from googlecloudsdk.api_lib.storage import storage_util
|
|
from googlecloudsdk.api_lib.util import apis
|
|
from googlecloudsdk.api_lib.util import exceptions
|
|
from googlecloudsdk.api_lib.util import waiter
|
|
from googlecloudsdk.command_lib.ml.speech import flag_validations
|
|
from googlecloudsdk.core import properties
|
|
from googlecloudsdk.core import resources
|
|
from googlecloudsdk.core.util import files
|
|
from six.moves import urllib
|
|
|
|
|
|
_API_NAME = 'speech'
|
|
_API_VERSION = 'v2'
|
|
|
|
|
|
@contextlib.contextmanager
def _OverrideEndpoint(override):
  """Temporarily points the speech API endpoint override at `override`.

  Saves the current value of the `speech` entry in
  properties.VALUES.api_endpoint_overrides, installs `override`, and restores
  the saved value on exit (including on exceptions).

  Args:
    override: The endpoint URL to install for the duration of the context.

  Yields:
    Nothing; the override is active inside the `with` body.
  """
  override_property = getattr(
      properties.VALUES.api_endpoint_overrides, _API_NAME
  )
  previous_value = override_property.Get()
  override_property.Set(override)
  try:
    yield
  finally:
    # Always restore, even if the body raised.
    override_property.Set(previous_value)
class SpeechV2Client(object):
  """Speech V2 API client wrappers."""

  def __init__(self):
    client_class = apis.GetClientClass(_API_NAME, _API_VERSION)
    # Host part of the base URL; regional endpoints are derived from it.
    self._net_loc = urllib.parse.urlsplit(client_class.BASE_URL).netloc

    self._resource_parser = resources.Registry()
    self._resource_parser.RegisterApiByName(_API_NAME, _API_VERSION)

    msgs = apis.GetMessagesModule(_API_NAME, _API_VERSION)
    encoding_enum = msgs.ExplicitDecodingConfig.EncodingValueValuesEnum
    # Flag value -> explicit-decoding encoding enum member.
    self._encoding_to_message = {
        name: getattr(encoding_enum, name)
        for name in ('LINEAR16', 'MULAW', 'ALAW')
    }
    self._messages = msgs
def _GetClientForLocation(self, location):
|
|
with _OverrideEndpoint('https://{}-{}/'.format(location, self._net_loc)):
|
|
return apis.GetClientInstance(_API_NAME, _API_VERSION)
|
|
|
|
def _RecognizerServiceForLocation(self, location):
|
|
return self._GetClientForLocation(location).projects_locations_recognizers
|
|
|
|
def _OperationsServiceForLocation(self, location):
|
|
return self._GetClientForLocation(location).projects_locations_operations
|
|
|
|
def _LocationsServiceForLocation(self, location):
|
|
return self._GetClientForLocation(location).projects_locations
|
|
|
|
def CreateRecognizer(
|
|
self,
|
|
resource,
|
|
display_name,
|
|
model,
|
|
language_codes,
|
|
recognition_config,
|
|
):
|
|
"""Call API CreateRecognizer method with provided arguments."""
|
|
recognizer = self._messages.Recognizer(displayName=display_name)
|
|
|
|
recognizer.model = model
|
|
recognizer.languageCodes = language_codes
|
|
|
|
recognizer.defaultRecognitionConfig = recognition_config
|
|
|
|
request = self._messages.SpeechProjectsLocationsRecognizersCreateRequest(
|
|
parent=resource.Parent(
|
|
parent_collection='speech.projects.locations'
|
|
).RelativeName(),
|
|
recognizerId=resource.Name(),
|
|
recognizer=recognizer,
|
|
)
|
|
return self._RecognizerServiceForLocation(
|
|
location=resource.Parent().Name()
|
|
).Create(request)
|
|
|
|
def GetRecognizer(self, resource):
|
|
request = self._messages.SpeechProjectsLocationsRecognizersGetRequest(
|
|
name=resource.RelativeName()
|
|
)
|
|
return self._RecognizerServiceForLocation(
|
|
location=resource.Parent().Name()
|
|
).Get(request)
|
|
|
|
def DeleteRecognizer(self, resource):
|
|
request = self._messages.SpeechProjectsLocationsRecognizersDeleteRequest(
|
|
name=resource.RelativeName()
|
|
)
|
|
return self._RecognizerServiceForLocation(
|
|
location=resource.Parent().Name()
|
|
).Delete(request)
|
|
|
|
def ListRecognizers(self, location_resource, limit=None, page_size=None):
|
|
request = self._messages.SpeechProjectsLocationsRecognizersListRequest(
|
|
parent=location_resource.RelativeName()
|
|
)
|
|
if page_size:
|
|
request.page_size = page_size
|
|
return list_pager.YieldFromList(
|
|
self._RecognizerServiceForLocation(location_resource.Name()),
|
|
request,
|
|
limit=limit,
|
|
batch_size_attribute='pageSize',
|
|
batch_size=page_size,
|
|
field='recognizers',
|
|
)
|
|
|
|
def UpdateRecognizer(
|
|
self,
|
|
resource,
|
|
display_name,
|
|
model,
|
|
language_codes,
|
|
recognition_config,
|
|
update_mask,
|
|
):
|
|
"""Call API UpdateRecognizer method with provided arguments."""
|
|
recognizer = self._messages.Recognizer()
|
|
|
|
if display_name is not None:
|
|
recognizer.displayName = display_name
|
|
update_mask.append('display_name')
|
|
if model is not None:
|
|
recognizer.model = model
|
|
update_mask.append('model')
|
|
if language_codes is not None:
|
|
recognizer.languageCodes = language_codes
|
|
update_mask.append('language_codes')
|
|
|
|
recognizer.defaultRecognitionConfig = recognition_config
|
|
|
|
request = self._messages.SpeechProjectsLocationsRecognizersPatchRequest(
|
|
name=resource.RelativeName(),
|
|
recognizer=recognizer,
|
|
updateMask=','.join(update_mask),
|
|
)
|
|
return self._RecognizerServiceForLocation(
|
|
location=resource.Parent().Name()
|
|
).Patch(request)
|
|
|
|
def RunShort(
|
|
self,
|
|
resource,
|
|
audio,
|
|
recognition_config,
|
|
update_mask,
|
|
):
|
|
"""Call API Recognize method with provided arguments."""
|
|
recognize_req = self._messages.RecognizeRequest()
|
|
if os.path.isfile(audio):
|
|
recognize_req.content = files.ReadBinaryFileContents(audio)
|
|
elif storage_util.ObjectReference.IsStorageUrl(audio):
|
|
recognize_req.uri = audio
|
|
|
|
recognizer_service = self._RecognizerServiceForLocation(
|
|
location=resource.Parent().Name()
|
|
)
|
|
|
|
recognize_req.config = recognition_config
|
|
|
|
recognize_req.configMask = ','.join(update_mask)
|
|
|
|
request = self._messages.SpeechProjectsLocationsRecognizersRecognizeRequest(
|
|
recognizeRequest=recognize_req,
|
|
recognizer=resource.RelativeName(),
|
|
)
|
|
return recognizer_service.Recognize(request)
|
|
|
|
def RunBatch(
|
|
self,
|
|
resource,
|
|
audio,
|
|
recognition_config,
|
|
update_mask,
|
|
):
|
|
"""Call API Recognize method with provided arguments in batch mode."""
|
|
batch_audio_metadata = self._messages.BatchRecognizeFileMetadata(uri=audio)
|
|
recognize_req = self._messages.BatchRecognizeRequest(
|
|
recognizer=resource.RelativeName(),
|
|
files=[batch_audio_metadata],
|
|
)
|
|
|
|
recognizer_service = self._RecognizerServiceForLocation(
|
|
location=resource.Parent().Name()
|
|
)
|
|
|
|
recognize_req.config = recognition_config
|
|
|
|
recognize_req.recognitionOutputConfig = (
|
|
self._messages.RecognitionOutputConfig(
|
|
inlineResponseConfig=self._messages.InlineOutputConfig()
|
|
)
|
|
)
|
|
|
|
recognize_req.configMask = ','.join(update_mask)
|
|
|
|
return recognizer_service.BatchRecognize(recognize_req)
|
|
|
|
def GetOperationRef(self, operation):
|
|
"""Converts an Operation to a Resource."""
|
|
return self._resource_parser.ParseRelativeName(
|
|
operation.name, 'speech.projects.locations.operations'
|
|
)
|
|
|
|
def WaitForRecognizerOperation(self, location, operation_ref, message):
|
|
"""Waits for a Recognizer operation to complete.
|
|
|
|
Polls the Speech Operation service until the operation completes, fails, or
|
|
max_wait_ms elapses.
|
|
|
|
Args:
|
|
location: The location of the resource.
|
|
operation_ref: A Resource created by GetOperationRef describing the
|
|
Operation.
|
|
message: The message to display to the user while they wait.
|
|
|
|
Returns:
|
|
An Endpoint entity.
|
|
"""
|
|
poller = waiter.CloudOperationPoller(
|
|
result_service=self._RecognizerServiceForLocation(location),
|
|
operation_service=self._OperationsServiceForLocation(location),
|
|
)
|
|
|
|
return waiter.WaitFor(
|
|
poller=poller,
|
|
operation_ref=operation_ref,
|
|
message=message,
|
|
pre_start_sleep_ms=100,
|
|
max_wait_ms=20000,
|
|
)
|
|
|
|
def WaitForBatchRecognizeOperation(self, location, operation_ref, message):
|
|
"""Waits for a Batch Recognize operation to complete.
|
|
|
|
Polls the Speech Operation service until the operation completes, fails, or
|
|
max_wait_ms elapses.
|
|
|
|
Args:
|
|
location: The location of the resource.
|
|
operation_ref: A Resource created by GetOperationRef describing the
|
|
Operation.
|
|
message: The message to display to the user while they wait.
|
|
|
|
Returns:
|
|
An Endpoint entity.
|
|
"""
|
|
poller = waiter.CloudOperationPollerNoResources(
|
|
self._OperationsServiceForLocation(location),
|
|
lambda x: x,
|
|
)
|
|
|
|
return waiter.WaitFor(
|
|
poller,
|
|
operation_ref,
|
|
message=message,
|
|
wait_ceiling_ms=86400000,
|
|
)
|
|
|
|
def GetLocation(self, location_resource):
|
|
request = self._messages.SpeechProjectsLocationsGetRequest(
|
|
name=location_resource.RelativeName()
|
|
)
|
|
return self._LocationsServiceForLocation(
|
|
location=location_resource.Name()
|
|
).Get(request)
|
|
|
|
def ListLocations(self, filter_str=None, limit=None, page_size=None):
|
|
request = self._messages.SpeechProjectsLocationsListRequest(
|
|
name=properties.VALUES.core.project.Get()
|
|
)
|
|
if filter_str:
|
|
request.filter = filter_str
|
|
if page_size:
|
|
request.page_size = page_size
|
|
return list_pager.YieldFromList(
|
|
self._LocationsServiceForLocation('global'),
|
|
request,
|
|
limit=limit,
|
|
batch_size_attribute='pageSize',
|
|
batch_size=page_size,
|
|
field='locations',
|
|
)
|
|
|
|
def InitializeRecognitionConfig(
|
|
self, model=None, language_codes=None, update_mask=None
|
|
):
|
|
"""creates a recognition config object and initializes it with model and language codes."""
|
|
recognition_config = self._messages.RecognitionConfig()
|
|
if model is not None:
|
|
recognition_config.model = model
|
|
if language_codes is not None:
|
|
recognition_config.languageCodes = language_codes
|
|
|
|
if update_mask is None:
|
|
return recognition_config, update_mask
|
|
|
|
if model is not None:
|
|
update_mask.append('model')
|
|
if language_codes is not None:
|
|
update_mask.append('language_codes')
|
|
return recognition_config, update_mask
|
|
|
|
def InitializeDecodingConfigFromArgs(
|
|
self,
|
|
recognition_config,
|
|
args,
|
|
default_to_auto_decoding_config=False,
|
|
update_mask=None,
|
|
):
|
|
|
|
return self._InitializeDecodingConfigRecognizerCommand(
|
|
recognition_config,
|
|
args.encoding,
|
|
args.sample_rate,
|
|
args.audio_channel_count,
|
|
default_to_auto_decoding_config=default_to_auto_decoding_config,
|
|
update_mask=update_mask,
|
|
)
|
|
|
|
def _InitializeDecodingConfigRecognizerCommand(
|
|
self,
|
|
recognition_config,
|
|
encoding,
|
|
sample_rate,
|
|
audio_channel_count,
|
|
default_to_auto_decoding_config=False,
|
|
update_mask=None,
|
|
):
|
|
"""Initializes encoding type based on auto (or explicit decoding option), sample rate and audio channel count."""
|
|
if encoding is not None:
|
|
if encoding == 'AUTO':
|
|
recognition_config.autoDecodingConfig = (
|
|
self._messages.AutoDetectDecodingConfig()
|
|
)
|
|
|
|
elif encoding in flag_validations.EXPLICIT_ENCODING_OPTIONS:
|
|
recognition_config.explicitDecodingConfig = (
|
|
self._messages.ExplicitDecodingConfig()
|
|
)
|
|
|
|
recognition_config.explicitDecodingConfig.encoding = (
|
|
self._encoding_to_message[encoding]
|
|
)
|
|
|
|
if sample_rate is not None:
|
|
recognition_config.explicitDecodingConfig.sampleRateHertz = (
|
|
sample_rate
|
|
)
|
|
|
|
if audio_channel_count is not None:
|
|
recognition_config.explicitDecodingConfig.audioChannelCount = (
|
|
audio_channel_count
|
|
)
|
|
else:
|
|
raise exceptions.InvalidArgumentException(
|
|
'--encoding',
|
|
'[--encoding] must be set to LINEAR16, MULAW, ALAW, or AUTO.',
|
|
)
|
|
elif default_to_auto_decoding_config:
|
|
recognition_config.autoDecodingConfig = (
|
|
self._messages.AutoDetectDecodingConfig()
|
|
)
|
|
|
|
if update_mask is None:
|
|
return recognition_config, update_mask
|
|
|
|
if encoding == 'AUTO':
|
|
update_mask.append('auto_decoding_config')
|
|
elif encoding in flag_validations.EXPLICIT_ENCODING_OPTIONS:
|
|
update_mask.append('explicit_decoding_config')
|
|
elif default_to_auto_decoding_config:
|
|
update_mask.append('auto_decoding_config')
|
|
if sample_rate is not None:
|
|
if recognition_config.explicitDecodingConfig is None:
|
|
recognition_config.explicitDecodingConfig = (
|
|
self._messages.ExplicitDecodingConfig()
|
|
)
|
|
recognition_config.explicitDecodingConfig.sampleRateHertz = sample_rate
|
|
update_mask.append('explicit_decoding_config.sample_rate_hertz')
|
|
if audio_channel_count is not None:
|
|
if recognition_config.explicitDecodingConfig is None:
|
|
recognition_config.explicitDecodingConfig = (
|
|
self._messages.ExplicitDecodingConfig()
|
|
)
|
|
recognition_config.explicitDecodingConfig.audioChannelCount = (
|
|
audio_channel_count
|
|
)
|
|
update_mask.append('explicit_decoding_config.audio_channel_count')
|
|
|
|
return recognition_config, update_mask
|
|
|
|
def InitializeAdaptationConfigFromArgs(
|
|
self,
|
|
args,
|
|
update_mask=None,
|
|
):
|
|
"""Initializes PhraseSets based on hints."""
|
|
return self._InitializeAdaptationConfigRecognizeRequest(
|
|
args.hint_phrases, args.hint_phrase_sets, args.hint_boost, update_mask
|
|
)
|
|
|
|
def _InitializeAdaptationConfigRecognizeRequest(
|
|
self, hint_phrases, hint_phrase_sets, hint_boost=5.0, update_mask=None
|
|
):
|
|
"""Initializes PhraseSets based on phrases and phrase sets."""
|
|
speech_adaptation_phrase_sets = []
|
|
|
|
if hint_phrases:
|
|
inline_phrase_set = self._messages.PhraseSet(
|
|
phrases=[
|
|
self._messages.Phrase(value=hint_phrase, boost=5.0)
|
|
for hint_phrase in hint_phrases
|
|
],
|
|
boost=hint_boost,
|
|
)
|
|
inline_adaptation_phrase_set = self._messages.AdaptationPhraseSet(
|
|
inlinePhraseSet=inline_phrase_set
|
|
)
|
|
speech_adaptation_phrase_sets.append(inline_adaptation_phrase_set)
|
|
|
|
if hint_phrase_sets:
|
|
for hint_phrase_set in hint_phrase_sets:
|
|
adaptation_phrase_set = self._messages.AdaptationPhraseSet(
|
|
phraseSet=hint_phrase_set
|
|
)
|
|
speech_adaptation_phrase_sets.append(adaptation_phrase_set)
|
|
|
|
speech_adaptation_config = self._messages.SpeechAdaptation(
|
|
phraseSets=speech_adaptation_phrase_sets
|
|
)
|
|
|
|
if update_mask is not None:
|
|
update_mask.append('adaptation')
|
|
return speech_adaptation_config, update_mask
|
|
|
|
def InitializeASRFeaturesFromArgs(
|
|
self,
|
|
args,
|
|
update_mask=None,
|
|
):
|
|
"""Collects features from the provided arguments."""
|
|
features_config = self._messages.RecognitionFeatures()
|
|
inner_update_mask = []
|
|
if args.profanity_filter is not None:
|
|
features_config.profanityFilter = args.profanity_filter
|
|
inner_update_mask.append('features.profanityFilter')
|
|
|
|
if args.enable_word_time_offsets is not None:
|
|
features_config.enableWordTimeOffsets = args.enable_word_time_offsets
|
|
inner_update_mask.append('features.enableWordTimeOffsets')
|
|
if args.enable_word_confidence is not None:
|
|
features_config.enableWordConfidence = args.enable_word_confidence
|
|
inner_update_mask.append('features.enableWordConfidence')
|
|
if args.enable_automatic_punctuation is not None:
|
|
features_config.enableAutomaticPunctuation = (
|
|
args.enable_automatic_punctuation
|
|
)
|
|
inner_update_mask.append('features.enableAutomaticPunctuation')
|
|
if args.enable_spoken_punctuation is not None:
|
|
features_config.enableSpokenPunctuation = args.enable_spoken_punctuation
|
|
inner_update_mask.append('features.enableSpokenPunctuation')
|
|
if args.enable_spoken_emojis is not None:
|
|
features_config.enableSpokenEmojis = args.enable_spoken_emojis
|
|
inner_update_mask.append('features.enableSpokenEmojis')
|
|
if (
|
|
args.min_speaker_count is not None
|
|
and args.max_speaker_count is not None
|
|
):
|
|
features_config.diarizationConfig = (
|
|
self._messages.SpeakerDiarizationConfig(
|
|
minSpeakerCount=args.min_speaker_count,
|
|
maxSpeakerCount=args.max_speaker_count,
|
|
)
|
|
)
|
|
inner_update_mask.append('features.diarizationConfig.minSpeakerCount')
|
|
inner_update_mask.append('features.diarizationConfig.maxSpeakerCount')
|
|
if args.separate_channel_recognition:
|
|
features_config.multiChannelMode = (
|
|
self._messages.RecognitionFeatures.MultiChannelModeValueValuesEnum.SEPARATE_RECOGNITION_PER_CHANNEL
|
|
)
|
|
inner_update_mask.append('features.multiChannelMode')
|
|
elif args.separate_channel_recognition is not None:
|
|
features_config.multiChannelMode = (
|
|
self._messages.RecognitionFeatures.MultiChannelModeValueValuesEnum.MULTI_CHANNEL_MODE_UNSPECIFIED
|
|
)
|
|
inner_update_mask.append('features.multiChannelMode')
|
|
if args.max_alternatives is not None:
|
|
features_config.maxAlternatives = args.max_alternatives
|
|
inner_update_mask.append('features.maxAlternatives')
|
|
|
|
if update_mask is not None:
|
|
update_mask.extend(inner_update_mask)
|
|
|
|
return features_config, update_mask
|