Some updates to the colab and script (#38)

* changing streaming to be robust to repeated generation

* some changes

* plop

* plop

* plop

* plop
This commit is contained in:
Alexandre Défossez 2025-07-03 15:06:37 +02:00 committed by GitHub
parent c1d248abba
commit eae5e17975
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 9 deletions

View File

@@ -51,7 +51,6 @@ def main():
     tts_model = TTSModel.from_checkpoint_info(
         checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
     )
-    tts_model.mimi.streaming_forever(batch_size=1)
     if args.inp == "-":
         if sys.stdin.isatty():  # Interactive
@@ -61,11 +60,12 @@ def main():
         with open(args.inp, "r") as fobj:
             text = fobj.read().strip()
-    # You could also generate multiple audios at once by passing a list of texts.
+    # If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]
     entries = tts_model.prepare_script([text], padding_between=1)
     voice_path = tts_model.get_voice_path(args.voice)
     # CFG coef goes here because the model was trained with CFG distillation,
     # so it's not _actually_ doing CFG at inference time.
+    # Also, if you are generating a dialog, you should have two voices in the list.
     condition_attributes = tts_model.make_condition_attributes(
         [voice_path], cfg_coef=2.0
     )
@@ -94,7 +94,8 @@ def main():
             channels=1,
             callback=audio_callback,
         ):
-            tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
+            with tts_model.mimi.streaming(1):
+                tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
             time.sleep(3)
             while True:
                 if pcms.qsize() == 0:
@@ -102,7 +103,7 @@ def main():
                     time.sleep(1)
     else:
         result = tts_model.generate([entries], [condition_attributes])
-        with torch.no_grad():
+        with tts_model.mimi.streaming(1), torch.no_grad():
             pcms = []
             for frame in result.frames[tts_model.delay_steps :]:
                 pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()

View File

@@ -7,7 +7,11 @@
     "metadata": {},
     "outputs": [],
     "source": [
-     "!pip install \"moshi==0.2.7\""
+     "# Fast install, might break in the future.\n",
+     "!pip install 'sphn<0.2'\n",
+     "!pip install --no-deps \"moshi==0.2.7\"\n",
+     "# Slow install (will download torch and cuda), but future proof.\n",
+     "# !pip install \"moshi==0.2.7\""
     ]
    },
    {
@@ -21,7 +25,6 @@
     "import sys\n",
     "\n",
     "import numpy as np\n",
-    "import sphn\n",
     "import torch\n",
     "from moshi.models.loaders import CheckpointInfo\n",
     "from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel\n",
@@ -54,13 +57,13 @@
     "tts_model = TTSModel.from_checkpoint_info(\n",
     "    checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n",
     ")\n",
-    "tts_model.mimi.streaming_forever(1)\n",
     "\n",
-    "# You could also generate multiple audios at once by passing a list of texts.\n",
+    "# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n",
     "entries = tts_model.prepare_script([text], padding_between=1)\n",
     "voice_path = tts_model.get_voice_path(voice)\n",
     "# CFG coef goes here because the model was trained with CFG distillation,\n",
     "# so it's not _actually_ doing CFG at inference time.\n",
+    "# Also, if you are generating a dialog, you should have two voices in the list.\n",
     "condition_attributes = tts_model.make_condition_attributes(\n",
     "    [voice_path], cfg_coef=2.0\n",
     ")"
@@ -82,7 +85,11 @@
     "    pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n",
     "    pcms.append(np.clip(pcm[0, 0], -1, 1))\n",
     "\n",
-    "result = tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)\n",
+    "# You could also generate multiple audios at once by extending the following lists.\n",
+    "all_entries = [entries]\n",
+    "all_condition_attributes = [condition_attributes]\n",
+    "with tts_model.mimi.streaming(len(all_entries)):\n",
+    "    result = tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)\n",
     "\n",
     "print(\"Done generating.\")\n",
     "audio = np.concatenate(pcms, axis=-1)"