From b2416b19dd61e4d8da15db7c758ccbbf07299f7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20D=C3=A9fossez?= Date: Thu, 3 Jul 2025 14:38:25 +0200 Subject: [PATCH] changing streaming to be robust to repeated generation --- scripts/tts_pytorch.py | 9 +++++---- tts_pytorch.ipynb | 10 +++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py index c3e7f1b..c485614 100644 --- a/scripts/tts_pytorch.py +++ b/scripts/tts_pytorch.py @@ -51,7 +51,6 @@ def main(): tts_model = TTSModel.from_checkpoint_info( checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half ) - tts_model.mimi.streaming_forever(batch_size=1) if args.inp == "-": if sys.stdin.isatty(): # Interactive @@ -61,7 +60,7 @@ def main(): with open(args.inp, "r") as fobj: text = fobj.read().strip() - # You could also generate multiple audios at once by passing a list of texts. + # If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...] entries = tts_model.prepare_script([text], padding_between=1) voice_path = tts_model.get_voice_path(args.voice) # CFG coef goes here because the model was trained with CFG distillation, @@ -76,6 +75,7 @@ def main(): pcms = queue.Queue() + @torch.no_grad() def _on_frame(frame): if (frame != -1).all(): pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy() @@ -94,7 +94,8 @@ def main(): channels=1, callback=audio_callback, ): - tts_model.generate([entries], [condition_attributes], on_frame=_on_frame) + with tts_model.mimi.streaming(1): + tts_model.generate([entries], [condition_attributes], on_frame=_on_frame) time.sleep(3) while True: if pcms.qsize() == 0: @@ -102,7 +103,7 @@ def main(): time.sleep(1) else: result = tts_model.generate([entries], [condition_attributes]) - with torch.no_grad(): + with tts_model.mimi.streaming(1), torch.no_grad(): pcms = [] for frame in result.frames[tts_model.delay_steps :]: pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy() diff --git a/tts_pytorch.ipynb b/tts_pytorch.ipynb index d3b6daf..27f2089 100644 --- a/tts_pytorch.ipynb +++ b/tts_pytorch.ipynb @@ -54,13 +54,13 @@ "tts_model = TTSModel.from_checkpoint_info(\n", " checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n", ")\n", - "tts_model.mimi.streaming_forever(1)\n", "\n", - "# You could also generate multiple audios at once by passing a list of texts.\n", + "# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n", "entries = tts_model.prepare_script([text], padding_between=1)\n", "voice_path = tts_model.get_voice_path(voice)\n", "# CFG coef goes here because the model was trained with CFG distillation,\n", "# so it's not _actually_ doing CFG at inference time.\n", + "# Also, if you are generating a dialog, you should have at least two voices in the list.\n", "condition_attributes = tts_model.make_condition_attributes(\n", " [voice_path], cfg_coef=2.0\n", ")" @@ -82,7 +82,11 @@ " pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n", " pcms.append(np.clip(pcm[0, 0], -1, 1))\n", "\n", - "result = tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)\n", + "# You could also generate multiple audios at once by extending the following lists.\n", + "all_entries = [entries]\n", + "all_condition_attributes = [condition_attributes]\n", + "with tts_model.mimi.streaming(len(all_entries)):\n", + " result = tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)\n", "\n", "print(\"Done generating.\")\n", "audio = np.concatenate(pcms, axis=-1)"