From 236df99003cc48e65b4c4a3d8c1a3e241834f97a Mon Sep 17 00:00:00 2001 From: Laurent Date: Wed, 2 Jul 2025 16:51:04 +0200 Subject: [PATCH] Allow for playing the audio in a streaming way. --- scripts/tts_mlx.py | 70 +++++++++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/scripts/tts_mlx.py b/scripts/tts_mlx.py index 9b72bfc..050fce5 100644 --- a/scripts/tts_mlx.py +++ b/scripts/tts_mlx.py @@ -2,16 +2,16 @@ # requires-python = ">=3.12" # dependencies = [ # "huggingface_hub", -# "moshi_mlx", +# "moshi_mlx>=0.2.8", # "numpy", # "sounddevice", # ] # /// import argparse -from dataclasses import dataclass import json from pathlib import Path +import queue import time import numpy as np @@ -104,6 +104,7 @@ def main(): cfg_is_no_prefix = True mimi = tts_model.mimi + log("info", f"reading input from {args.inp}") with open(args.inp, "r") as fobj: text_to_tts = fobj.read().strip() @@ -114,39 +115,56 @@ def main(): voices = [] all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)] - begin = time.time() - result = tts_model.generate( - all_entries, all_attributes, - cfg_is_no_prefix=cfg_is_no_prefix, cfg_is_no_text=cfg_is_no_text) - frames = mx.concat(result.frames, axis=-1) - total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate - time_taken = time.time() - begin - total_speed = total_duration / time_taken - log("info", f"[LM] took {time_taken:.2f}s, total speed {total_speed:.2f}x") + wav_frames = queue.Queue() + def _on_audio_hook(audio_tokens): + if (audio_tokens == -1).any(): + return + _pcm = tts_model.mimi.decode_step(audio_tokens[None, :, None]) + _pcm = np.array(mx.clip(_pcm[0, 0], -1, 1)) + wav_frames.put_nowait(_pcm) + + def run(): + log("info", "starting the inference loop") + begin = time.time() + result = tts_model.generate( + all_entries, + all_attributes, + cfg_is_no_prefix=cfg_is_no_prefix, + cfg_is_no_text=cfg_is_no_text, + on_audio_hook=_on_audio_hook, + ) + frames = mx.concat(result.frames, axis=-1) + total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate + time_taken = time.time() - begin + total_speed = total_duration / time_taken + log("info", f"[LM] took {time_taken:.2f}s, total speed {total_speed:.2f}x") + return result - wav_frames = [] - for frame in result.frames: - # We are processing frames one by one, although we could group them to improve speed. - _pcm = tts_model.mimi.decode_step(frame) - wav_frames.append(_pcm) if args.out == "-": - cnt = [0] def audio_callback(outdata, _a, _b, _c): - if cnt[0] < len(wav_frames): - outdata[:, 0] = wav_frames[cnt[0]][0, 0] - cnt[0] += 1 - else: + try: + pcm_data = wav_frames.get(block=False) + outdata[:, 0] = pcm_data + except queue.Empty: outdata[:] = 0 with sd.OutputStream(samplerate=mimi.sample_rate, blocksize=1920, channels=1, callback=audio_callback): - time.sleep(10) + run() + time.sleep(3) + while True: + if wav_frames.qsize() == 0: + break + time.sleep(1) else: - wavs = mx.concat(wav_frames, axis=-1) - end_step = result.end_steps[0] - wav_length = int((mimi.sample_rate * (end_step + tts_model.final_padding) / mimi.frame_rate)) - wav = np.array(mx.clip(wavs[0, :, :wav_length], -1, 1)) + frames = [] + while True: + try: + frames.append(wav_frames.get_nowait()) + except queue.Empty: + break + wav = np.concat(frames, -1) sphn.write_wav(args.out, wav, mimi.sample_rate)