Mitigate TTS ResultStream leak in pipeline (#173290)

This commit is contained in:
Michael Hansen
2026-06-08 11:13:53 -05:00
committed by GitHub
parent 37e4f1ab32
commit e38e6ecec8
3 changed files with 103 additions and 4 deletions
@@ -1816,6 +1816,11 @@ class PipelineInput:
await self.run.text_to_speech(tts_input)
except PipelineError as err:
if self.run.tts_stream:
# Clean up TTS stream
self.run.tts_stream.delete()
self.run.tts_stream = None
self.run.process_event(
PipelineEvent(
PipelineEventType.ERROR,
@@ -1885,15 +1890,17 @@ class PipelineInput:
):
prepare_tasks.append(self.run.prepare_recognize_intent(self.session))
if prepare_tasks:
await asyncio.gather(*prepare_tasks)
# Do TTS prepare separately so we don't create a ResultStream if the
# pipeline is invalid.
if (
start_stage_index
<= PIPELINE_STAGE_ORDER.index(PipelineStage.TTS)
<= end_stage_index
):
prepare_tasks.append(self.run.prepare_text_to_speech())
if prepare_tasks:
await asyncio.gather(*prepare_tasks)
await self.run.prepare_text_to_speech()
class PipelinePreferred(CollectionError):
+9
View File
@@ -613,6 +613,10 @@ class ResultStream:
async for chunk in converted_audio:
yield chunk
def delete(self) -> None:
"""Remove the result stream from the manager."""
self._manager.async_delete_result_stream(self.token)
def _hash_options(options: dict) -> str:
"""Hashes an options dictionary."""
@@ -809,6 +813,11 @@ class SpeechManager:
stream.last_used = monotonic()
return stream
@callback
def async_delete_result_stream(self, token: str) -> None:
"""Delete a result stream given a token."""
self.token_to_stream.pop(token, None)
@callback
def async_create_result_stream(
self,
@@ -2274,3 +2274,86 @@ async def test_stt_vad_enabled_based_on_audio_processing(
# VAD should NOT be created when requires_external_vad is False
mock_vad.assert_not_called()
async def test_invalid_pipeline_does_not_create_tts_stream(
hass: HomeAssistant,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_components,
) -> None:
"""Test that an invalid pipeline won't create a TTS ResultStream."""
pipeline = async_get_pipeline(hass, None)
await async_update_pipeline(hass, pipeline, stt_engine="does-not-exist")
async def audio_data() -> AsyncGenerator[bytes]:
yield make_10ms_chunk(b"not used")
with patch.object(
mock_wake_word_provider_entity,
"async_process_audio_stream",
side_effect=assist_pipeline.error.WakeWordTimeoutError(
code="timeout", message="timeout"
),
):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=lambda event: None,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
start_stage=assist_pipeline.PipelineStage.STT,
end_stage=assist_pipeline.PipelineStage.TTS,
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
assert len(hass.data[tts.DATA_TTS_MANAGER].token_to_stream) == 0
async def test_pipeline_error_before_tts_does_not_leak_result_stream(
hass: HomeAssistant,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_components,
) -> None:
"""Test that a pipeline error before TTS will not leak a ResultStream."""
async def audio_data() -> AsyncGenerator[bytes]:
yield make_10ms_chunk(b"not used")
with patch.object(
mock_wake_word_provider_entity,
"async_process_audio_stream",
side_effect=assist_pipeline.error.WakeWordTimeoutError(
code="timeout", message="timeout"
),
):
for i in range(10):
with patch("secrets.token_urlsafe", return_value=f"mocked-token-{i}"):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=lambda event: None,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
end_stage=assist_pipeline.PipelineStage.TTS,
wake_word_settings=assist_pipeline.WakeWordSettings(
audio_seconds_to_buffer=1.5
),
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
assert len(hass.data[tts.DATA_TTS_MANAGER].token_to_stream) == 0