Compare commits

...

6 Commits

Author SHA1 Message Date
Alexandre Défossez
5afa2fe656 plop 2025-07-03 15:06:28 +02:00
Alexandre Défossez
774ef275a4 plop 2025-07-03 15:04:21 +02:00
Alexandre Défossez
080f349403 plop 2025-07-03 15:04:01 +02:00
Alexandre Défossez
0b5a5fbed2 plop 2025-07-03 14:55:24 +02:00
Alexandre Défossez
31f8746881 some changes 2025-07-03 14:52:02 +02:00
Alexandre Défossez
b2416b19dd changing streaming to be robust to repeated generation 2025-07-03 14:38:25 +02:00
2 changed files with 17 additions and 9 deletions

View File

@ -51,7 +51,6 @@ def main():
tts_model = TTSModel.from_checkpoint_info( tts_model = TTSModel.from_checkpoint_info(
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
) )
tts_model.mimi.streaming_forever(batch_size=1)
if args.inp == "-": if args.inp == "-":
if sys.stdin.isatty(): # Interactive if sys.stdin.isatty(): # Interactive
@ -61,11 +60,12 @@ def main():
with open(args.inp, "r") as fobj: with open(args.inp, "r") as fobj:
text = fobj.read().strip() text = fobj.read().strip()
# You could also generate multiple audios at once by passing a list of texts. # If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]
entries = tts_model.prepare_script([text], padding_between=1) entries = tts_model.prepare_script([text], padding_between=1)
voice_path = tts_model.get_voice_path(args.voice) voice_path = tts_model.get_voice_path(args.voice)
# CFG coef goes here because the model was trained with CFG distillation, # CFG coef goes here because the model was trained with CFG distillation,
# so it's not _actually_ doing CFG at inference time. # so it's not _actually_ doing CFG at inference time.
# Also, if you are generating a dialog, you should have two voices in the list.
condition_attributes = tts_model.make_condition_attributes( condition_attributes = tts_model.make_condition_attributes(
[voice_path], cfg_coef=2.0 [voice_path], cfg_coef=2.0
) )
@ -94,7 +94,8 @@ def main():
channels=1, channels=1,
callback=audio_callback, callback=audio_callback,
): ):
tts_model.generate([entries], [condition_attributes], on_frame=_on_frame) with tts_model.mimi.streaming(1):
tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
time.sleep(3) time.sleep(3)
while True: while True:
if pcms.qsize() == 0: if pcms.qsize() == 0:
@ -102,7 +103,7 @@ def main():
time.sleep(1) time.sleep(1)
else: else:
result = tts_model.generate([entries], [condition_attributes]) result = tts_model.generate([entries], [condition_attributes])
with torch.no_grad(): with tts_model.mimi.streaming(1), torch.no_grad():
pcms = [] pcms = []
for frame in result.frames[tts_model.delay_steps :]: for frame in result.frames[tts_model.delay_steps :]:
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy() pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()

View File

@ -7,7 +7,11 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip install \"moshi==0.2.7\"" "# Fast install, might break in the future.\n",
"!pip install 'sphn<0.2'\n",
"!pip install --no-deps \"moshi==0.2.7\"\n",
"# Slow install (will download torch and cuda), but future proof.\n",
"# !pip install \"moshi==0.2.7\""
] ]
}, },
{ {
@ -21,7 +25,6 @@
"import sys\n", "import sys\n",
"\n", "\n",
"import numpy as np\n", "import numpy as np\n",
"import sphn\n",
"import torch\n", "import torch\n",
"from moshi.models.loaders import CheckpointInfo\n", "from moshi.models.loaders import CheckpointInfo\n",
"from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel\n", "from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel\n",
@ -54,13 +57,13 @@
"tts_model = TTSModel.from_checkpoint_info(\n", "tts_model = TTSModel.from_checkpoint_info(\n",
" checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n", " checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n",
")\n", ")\n",
"tts_model.mimi.streaming_forever(1)\n",
"\n", "\n",
"# You could also generate multiple audios at once by passing a list of texts.\n", "# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n",
"entries = tts_model.prepare_script([text], padding_between=1)\n", "entries = tts_model.prepare_script([text], padding_between=1)\n",
"voice_path = tts_model.get_voice_path(voice)\n", "voice_path = tts_model.get_voice_path(voice)\n",
"# CFG coef goes here because the model was trained with CFG distillation,\n", "# CFG coef goes here because the model was trained with CFG distillation,\n",
"# so it's not _actually_ doing CFG at inference time.\n", "# so it's not _actually_ doing CFG at inference time.\n",
"# Also, if you are generating a dialog, you should have two voices in the list.\n",
"condition_attributes = tts_model.make_condition_attributes(\n", "condition_attributes = tts_model.make_condition_attributes(\n",
" [voice_path], cfg_coef=2.0\n", " [voice_path], cfg_coef=2.0\n",
")" ")"
@ -82,7 +85,11 @@
" pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n", " pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n",
" pcms.append(np.clip(pcm[0, 0], -1, 1))\n", " pcms.append(np.clip(pcm[0, 0], -1, 1))\n",
"\n", "\n",
"result = tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)\n", "# You could also generate multiple audios at once by extending the following lists.\n",
"all_entries = [entries]\n",
"all_condition_attributes = [condition_attributes]\n",
"with tts_model.mimi.streaming(len(all_entries)):\n",
" result = tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)\n",
"\n", "\n",
"print(\"Done generating.\")\n", "print(\"Done generating.\")\n",
"audio = np.concatenate(pcms, axis=-1)" "audio = np.concatenate(pcms, axis=-1)"