Change streaming to be robust to repeated generation

This commit is contained in:
Alexandre Défossez 2025-07-03 14:38:25 +02:00
parent c1d248abba
commit b2416b19dd
2 changed files with 12 additions and 7 deletions

View File

@ -51,7 +51,6 @@ def main():
tts_model = TTSModel.from_checkpoint_info( tts_model = TTSModel.from_checkpoint_info(
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
) )
tts_model.mimi.streaming_forever(batch_size=1)
if args.inp == "-": if args.inp == "-":
if sys.stdin.isatty(): # Interactive if sys.stdin.isatty(): # Interactive
@ -61,7 +60,7 @@ def main():
with open(args.inp, "r") as fobj: with open(args.inp, "r") as fobj:
text = fobj.read().strip() text = fobj.read().strip()
# You could also generate multiple audios at once by passing a list of texts. # If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]
entries = tts_model.prepare_script([text], padding_between=1) entries = tts_model.prepare_script([text], padding_between=1)
voice_path = tts_model.get_voice_path(args.voice) voice_path = tts_model.get_voice_path(args.voice)
# CFG coef goes here because the model was trained with CFG distillation, # CFG coef goes here because the model was trained with CFG distillation,
@ -76,6 +75,7 @@ def main():
pcms = queue.Queue() pcms = queue.Queue()
@torch.no_grad()
def _on_frame(frame): def _on_frame(frame):
if (frame != -1).all(): if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy() pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
@ -94,6 +94,7 @@ def main():
channels=1, channels=1,
callback=audio_callback, callback=audio_callback,
): ):
with tts_model.mimi.streaming(1):
tts_model.generate([entries], [condition_attributes], on_frame=_on_frame) tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
time.sleep(3) time.sleep(3)
while True: while True:
@ -102,7 +103,7 @@ def main():
time.sleep(1) time.sleep(1)
else: else:
result = tts_model.generate([entries], [condition_attributes]) result = tts_model.generate([entries], [condition_attributes])
with torch.no_grad(): with tts_model.mimi.streaming(1), torch.no_grad():
pcms = [] pcms = []
for frame in result.frames[tts_model.delay_steps :]: for frame in result.frames[tts_model.delay_steps :]:
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy() pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()

View File

@ -54,13 +54,13 @@
"tts_model = TTSModel.from_checkpoint_info(\n", "tts_model = TTSModel.from_checkpoint_info(\n",
" checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n", " checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n",
")\n", ")\n",
"tts_model.mimi.streaming_forever(1)\n",
"\n", "\n",
"# You could also generate multiple audios at once by passing a list of texts.\n", "# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n",
"entries = tts_model.prepare_script([text], padding_between=1)\n", "entries = tts_model.prepare_script([text], padding_between=1)\n",
"voice_path = tts_model.get_voice_path(voice)\n", "voice_path = tts_model.get_voice_path(voice)\n",
"# CFG coef goes here because the model was trained with CFG distillation,\n", "# CFG coef goes here because the model was trained with CFG distillation,\n",
"# so it's not _actually_ doing CFG at inference time.\n", "# so it's not _actually_ doing CFG at inference time.\n",
"# Also, if you are generating a dialog, you should have at least two voices in the list.\n",
"condition_attributes = tts_model.make_condition_attributes(\n", "condition_attributes = tts_model.make_condition_attributes(\n",
" [voice_path], cfg_coef=2.0\n", " [voice_path], cfg_coef=2.0\n",
")" ")"
@ -82,7 +82,11 @@
" pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n", " pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n",
" pcms.append(np.clip(pcm[0, 0], -1, 1))\n", " pcms.append(np.clip(pcm[0, 0], -1, 1))\n",
"\n", "\n",
"result = tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)\n", "# You could also generate multiple audios at once by extending the following lists.\n",
"all_entries = [entries]\n",
"all_condition_attributes = [condition_attributes]\n",
"with tts_model.mimi.streaming(len(all_entries)):\n",
" result = tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)\n",
"\n", "\n",
"print(\"Done generating.\")\n", "print(\"Done generating.\")\n",
"audio = np.concatenate(pcms, axis=-1)" "audio = np.concatenate(pcms, axis=-1)"