Mitigate TTS ResultStream leak in pipeline (#173290)
This commit is contained in:
@@ -1816,6 +1816,11 @@ class PipelineInput:
|
||||
await self.run.text_to_speech(tts_input)
|
||||
|
||||
except PipelineError as err:
|
||||
if self.run.tts_stream:
|
||||
# Clean up TTS stream
|
||||
self.run.tts_stream.delete()
|
||||
self.run.tts_stream = None
|
||||
|
||||
self.run.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
@@ -1885,15 +1890,17 @@ class PipelineInput:
|
||||
):
|
||||
prepare_tasks.append(self.run.prepare_recognize_intent(self.session))
|
||||
|
||||
if prepare_tasks:
|
||||
await asyncio.gather(*prepare_tasks)
|
||||
|
||||
# Do TTS prepare separately so we don't create a ResultStream if the
|
||||
# pipeline is invalid.
|
||||
if (
|
||||
start_stage_index
|
||||
<= PIPELINE_STAGE_ORDER.index(PipelineStage.TTS)
|
||||
<= end_stage_index
|
||||
):
|
||||
prepare_tasks.append(self.run.prepare_text_to_speech())
|
||||
|
||||
if prepare_tasks:
|
||||
await asyncio.gather(*prepare_tasks)
|
||||
await self.run.prepare_text_to_speech()
|
||||
|
||||
|
||||
class PipelinePreferred(CollectionError):
|
||||
|
||||
@@ -613,6 +613,10 @@ class ResultStream:
|
||||
async for chunk in converted_audio:
|
||||
yield chunk
|
||||
|
||||
def delete(self) -> None:
    """Ask the owning manager to drop this result stream.

    The stream is identified to the manager by its own ``token``.
    """
    manager = self._manager
    manager.async_delete_result_stream(self.token)
|
||||
|
||||
|
||||
def _hash_options(options: dict) -> str:
|
||||
"""Hashes an options dictionary."""
|
||||
@@ -809,6 +813,11 @@ class SpeechManager:
|
||||
stream.last_used = monotonic()
|
||||
return stream
|
||||
|
||||
@callback
def async_delete_result_stream(self, token: str) -> None:
    """Forget the result stream registered under ``token``, if there is one."""
    streams = self.token_to_stream
    # The ``None`` default makes an unknown token a no-op rather than an error.
    streams.pop(token, None)
|
||||
|
||||
@callback
|
||||
def async_create_result_stream(
|
||||
self,
|
||||
|
||||
@@ -2274,3 +2274,86 @@ async def test_stt_vad_enabled_based_on_audio_processing(
|
||||
|
||||
# VAD should NOT be created when requires_external_vad is False
|
||||
mock_vad.assert_not_called()
|
||||
|
||||
|
||||
async def test_invalid_pipeline_does_not_create_tts_stream(
    hass: HomeAssistant,
    mock_wake_word_provider_entity: MockWakeWordEntity,
    init_components,
) -> None:
    """Check that running an invalid pipeline registers no TTS ResultStream."""
    # Make the pipeline invalid by giving it an STT engine id that is not set up.
    target_pipeline = async_get_pipeline(hass, None)
    await async_update_pipeline(hass, target_pipeline, stt_engine="does-not-exist")

    async def unused_audio() -> AsyncGenerator[bytes]:
        yield make_10ms_chunk(b"not used")

    wake_word_timeout = assist_pipeline.error.WakeWordTimeoutError(
        code="timeout", message="timeout"
    )
    metadata = stt.SpeechMetadata(
        language="",
        format=stt.AudioFormats.WAV,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    failing_wake_word = patch.object(
        mock_wake_word_provider_entity,
        "async_process_audio_stream",
        side_effect=wake_word_timeout,
    )

    with failing_wake_word:
        await assist_pipeline.async_pipeline_from_audio_stream(
            hass,
            context=Context(),
            event_callback=lambda event: None,
            stt_metadata=metadata,
            stt_stream=unused_audio(),
            start_stage=assist_pipeline.PipelineStage.STT,
            end_stage=assist_pipeline.PipelineStage.TTS,
            audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
        )

    # Nothing may remain registered with the TTS manager afterwards.
    tts_manager = hass.data[tts.DATA_TTS_MANAGER]
    assert len(tts_manager.token_to_stream) == 0
|
||||
|
||||
|
||||
async def test_pipeline_error_before_tts_does_not_leak_result_stream(
    hass: HomeAssistant,
    mock_wake_word_provider_entity: MockWakeWordEntity,
    init_components,
) -> None:
    """Check that pipeline runs failing ahead of TTS leave no ResultStream behind."""

    async def unused_audio() -> AsyncGenerator[bytes]:
        yield make_10ms_chunk(b"not used")

    # Every run fails at the wake word stage with a timeout error.
    wake_word_timeout = assist_pipeline.error.WakeWordTimeoutError(
        code="timeout", message="timeout"
    )
    failing_wake_word = patch.object(
        mock_wake_word_provider_entity,
        "async_process_audio_stream",
        side_effect=wake_word_timeout,
    )

    with failing_wake_word:
        for run_number in range(10):
            # NOTE(review): presumably a distinct token per run keeps leaked
            # streams from overwriting each other in the manager - confirm.
            fake_token = f"mocked-token-{run_number}"
            with patch("secrets.token_urlsafe", return_value=fake_token):
                metadata = stt.SpeechMetadata(
                    language="",
                    format=stt.AudioFormats.WAV,
                    codec=stt.AudioCodecs.PCM,
                    bit_rate=stt.AudioBitRates.BITRATE_16,
                    sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                    channel=stt.AudioChannels.CHANNEL_MONO,
                )
                await assist_pipeline.async_pipeline_from_audio_stream(
                    hass,
                    context=Context(),
                    event_callback=lambda event: None,
                    stt_metadata=metadata,
                    stt_stream=unused_audio(),
                    start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
                    end_stage=assist_pipeline.PipelineStage.TTS,
                    wake_word_settings=assist_pipeline.WakeWordSettings(
                        audio_seconds_to_buffer=1.5
                    ),
                    audio_settings=assist_pipeline.AudioSettings(
                        is_vad_enabled=False
                    ),
                )

    # After all failed runs the TTS manager must hold no streams.
    tts_manager = hass.data[tts.DATA_TTS_MANAGER]
    assert len(tts_manager.token_to_stream) == 0
|
||||
|
||||
Reference in New Issue
Block a user