From cecbe46d4bbdd8eb5f07ec17442bb1c5181dd0e4 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 16 Jul 2025 21:01:16 +0200 Subject: [PATCH] Fix the pytorch tts streaming example. --- scripts/tts_pytorch_streaming.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/scripts/tts_pytorch_streaming.py b/scripts/tts_pytorch_streaming.py index d808023..36e2799 100644 --- a/scripts/tts_pytorch_streaming.py +++ b/scripts/tts_pytorch_streaming.py @@ -21,13 +21,27 @@ from moshi.models.loaders import CheckpointInfo from moshi.conditioners import dropout_all_conditions from moshi.models.lm import LMGen from moshi.models.tts import ( + Entry, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel, ConditionAttributes, + script_to_entries, ) +def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]: + multi_speaker = first_turn and model.multi_speaker + return script_to_entries( + model.tokenizer, + model.machine.token_ids, + model.mimi.frame_rate, + [script], + multi_speaker=multi_speaker, + padding_between=1, + ) + + def _make_null( all_attributes: tp.Sequence[ConditionAttributes], ) -> list[ConditionAttributes]: @@ -206,9 +220,10 @@ def main(): channels=1, callback=audio_callback, ) and tts_model.mimi.streaming(1): + first_turn = True for line in sys.stdin: - # TODO: Fix the following to only include bos on the first line. - entries = tts_model.prepare_script([line.strip()], padding_between=1) + entries = prepare_script(tts_model, line.strip(), first_turn=first_turn) + first_turn = False for entry in entries: gen.append_entry(entry) gen.process() @@ -227,9 +242,10 @@ def main(): gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame) with tts_model.mimi.streaming(1): + first_turn = True for line in sys.stdin: - # TODO: Fix the following to only include bos on the first line. - entries = tts_model.prepare_script([line.strip()], padding_between=1) + entries = prepare_script(tts_model, line.strip(), first_turn=first_turn) + first_turn = False for entry in entries: gen.append_entry(entry) gen.process()