From a2f031deb52737584db5611176a16c28850d380d Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 16 Jul 2025 21:07:02 +0200 Subject: [PATCH] Fix the pytorch tts streaming example. (#84) * Fix the pytorch tts streaming example. * Edit the readme too. --- README.md | 8 ++++++++ scripts/tts_pytorch_streaming.py | 24 ++++++++++++++++++++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index bd6d97d..8df1709 100644 --- a/README.md +++ b/README.md @@ -237,6 +237,14 @@ echo "Hey, how are you?" | python scripts/tts_pytorch.py - - python scripts/tts_pytorch.py text_to_say.txt audio_output.wav ``` +The `tts_pytorch.py` script waits for all the text to be available before +starting the audio generation. A fully streaming implementation is available in +the `tts_pytorch_streaming.py` script, which can be used as follows: + +```bash +echo "Hey, how are you?" | python scripts/tts_pytorch_streaming.py audio_output.wav +``` + This requires the [moshi package](https://pypi.org/project/moshi/), which can be installed via pip. If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and just prefix the command above with `uvx --with moshi`. diff --git a/scripts/tts_pytorch_streaming.py b/scripts/tts_pytorch_streaming.py index d808023..36e2799 100644 --- a/scripts/tts_pytorch_streaming.py +++ b/scripts/tts_pytorch_streaming.py @@ -21,13 +21,27 @@ from moshi.models.loaders import CheckpointInfo from moshi.conditioners import dropout_all_conditions from moshi.models.lm import LMGen from moshi.models.tts import ( + Entry, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel, ConditionAttributes, + script_to_entries, ) +def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]: + multi_speaker = first_turn and model.multi_speaker + return script_to_entries( + model.tokenizer, + model.machine.token_ids, + model.mimi.frame_rate, + [script], + multi_speaker=multi_speaker, + padding_between=1, + ) + + def _make_null( all_attributes: tp.Sequence[ConditionAttributes], ) -> list[ConditionAttributes]: @@ -206,9 +220,10 @@ def main(): channels=1, callback=audio_callback, ) and tts_model.mimi.streaming(1): + first_turn = True for line in sys.stdin: - # TODO: Fix the following to only include bos on the first line. - entries = tts_model.prepare_script([line.strip()], padding_between=1) + entries = prepare_script(tts_model, line.strip(), first_turn=first_turn) + first_turn = False for entry in entries: gen.append_entry(entry) gen.process() @@ -227,9 +242,10 @@ def main(): gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame) with tts_model.mimi.streaming(1): + first_turn = True for line in sys.stdin: - # TODO: Fix the following to only include bos on the first line. - entries = tts_model.prepare_script([line.strip()], padding_between=1) + entries = prepare_script(tts_model, line.strip(), first_turn=first_turn) + first_turn = False for entry in entries: gen.append_entry(entry) gen.process()