From 61206d78c83a261643275c0c89d8bce67522d12b Mon Sep 17 00:00:00 2001 From: Laurent Date: Wed, 2 Jul 2025 15:49:31 +0200 Subject: [PATCH] Audio playback. --- scripts/tts_mlx.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/scripts/tts_mlx.py b/scripts/tts_mlx.py index 95c59e4..9b72bfc 100644 --- a/scripts/tts_mlx.py +++ b/scripts/tts_mlx.py @@ -19,6 +19,9 @@ import mlx.core as mx import mlx.nn as nn import sentencepiece import sphn +import time + +import sounddevice as sd from moshi_mlx.client_utils import make_log from moshi_mlx import models @@ -126,11 +129,25 @@ def main(): # We are processing frames one by one, although we could group them to improve speed. _pcm = tts_model.mimi.decode_step(frame) wav_frames.append(_pcm) - wavs = mx.concat(wav_frames, axis=-1) - end_step = result.end_steps[0] - wav_length = int((mimi.sample_rate * (end_step + tts_model.final_padding) / mimi.frame_rate)) - wav = wavs[0, :, :wav_length] - sphn.write_wav(args.out, np.array(mx.clip(wav, -1, 1)), mimi.sample_rate) + if args.out == "-": + cnt = [0] + def audio_callback(outdata, _a, _b, _c): + if cnt[0] < len(wav_frames): + outdata[:, 0] = wav_frames[cnt[0]][0, 0] + cnt[0] += 1 + else: + outdata[:] = 0 + with sd.OutputStream(samplerate=mimi.sample_rate, + blocksize=1920, + channels=1, + callback=audio_callback): + time.sleep(10) + else: + wavs = mx.concat(wav_frames, axis=-1) + end_step = result.end_steps[0] + wav_length = int((mimi.sample_rate * (end_step + tts_model.final_padding) / mimi.frame_rate)) + wav = np.array(mx.clip(wavs[0, :, :wav_length], -1, 1)) + sphn.write_wav(args.out, wav, mimi.sample_rate) if __name__ == "__main__":