Allow for playing the audio in a streaming way.

This commit is contained in:
Laurent 2025-07-02 16:51:04 +02:00
parent 61206d78c8
commit 236df99003

View File

@ -2,16 +2,16 @@
# requires-python = ">=3.12" # requires-python = ">=3.12"
# dependencies = [ # dependencies = [
# "huggingface_hub", # "huggingface_hub",
# "moshi_mlx", # "moshi_mlx>=0.2.8",
# "numpy", # "numpy",
# "sounddevice", # "sounddevice",
# ] # ]
# /// # ///
import argparse import argparse
from dataclasses import dataclass
import json import json
from pathlib import Path from pathlib import Path
import queue
import time import time
import numpy as np import numpy as np
@ -104,6 +104,7 @@ def main():
cfg_is_no_prefix = True cfg_is_no_prefix = True
mimi = tts_model.mimi mimi = tts_model.mimi
log("info", f"reading input from {args.inp}")
with open(args.inp, "r") as fobj: with open(args.inp, "r") as fobj:
text_to_tts = fobj.read().strip() text_to_tts = fobj.read().strip()
@ -114,39 +115,56 @@ def main():
voices = [] voices = []
all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)] all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
wav_frames = queue.Queue()
def _on_audio_hook(audio_tokens):
    """Decode one generation step to PCM and queue it on ``wav_frames``.

    Handed to ``tts_model.generate`` through its ``on_audio_hook`` argument.

    Args:
        audio_tokens: Audio tokens for the current step. Indexed below as
            ``[None, :, None]``, so this is presumably a 1-D array with one
            entry per codebook -- TODO confirm against moshi_mlx.

    Returns:
        None. The decoded chunk is delivered through ``wav_frames``.
    """
    # Steps where any entry is still -1 are dropped entirely.
    # NOTE(review): -1 looks like a "not generated yet" placeholder -- confirm.
    if not (audio_tokens == -1).any():
        # Add a leading and a trailing axis before handing the step to mimi.
        step_tokens = audio_tokens[None, :, None]
        decoded = tts_model.mimi.decode_step(step_tokens)
        # Take index 0 on the first two axes and clamp samples to [-1, 1].
        clipped = mx.clip(decoded[0, 0], -1, 1)
        # put_nowait: the hook never waits on the consumer side of the queue.
        wav_frames.put_nowait(np.array(clipped))
def run():
log("info", "starting the inference loop")
begin = time.time() begin = time.time()
result = tts_model.generate( result = tts_model.generate(
all_entries, all_attributes, all_entries,
cfg_is_no_prefix=cfg_is_no_prefix, cfg_is_no_text=cfg_is_no_text) all_attributes,
cfg_is_no_prefix=cfg_is_no_prefix,
cfg_is_no_text=cfg_is_no_text,
on_audio_hook=_on_audio_hook,
)
frames = mx.concat(result.frames, axis=-1) frames = mx.concat(result.frames, axis=-1)
total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
time_taken = time.time() - begin time_taken = time.time() - begin
total_speed = total_duration / time_taken total_speed = total_duration / time_taken
log("info", f"[LM] took {time_taken:.2f}s, total speed {total_speed:.2f}x") log("info", f"[LM] took {time_taken:.2f}s, total speed {total_speed:.2f}x")
return result
wav_frames = []
for frame in result.frames:
# We are processing frames one by one, although we could group them to improve speed.
_pcm = tts_model.mimi.decode_step(frame)
wav_frames.append(_pcm)
if args.out == "-": if args.out == "-":
cnt = [0]
def audio_callback(outdata, _a, _b, _c): def audio_callback(outdata, _a, _b, _c):
if cnt[0] < len(wav_frames): try:
outdata[:, 0] = wav_frames[cnt[0]][0, 0] pcm_data = wav_frames.get(block=False)
cnt[0] += 1 outdata[:, 0] = pcm_data
else: except queue.Empty:
outdata[:] = 0 outdata[:] = 0
with sd.OutputStream(samplerate=mimi.sample_rate, with sd.OutputStream(samplerate=mimi.sample_rate,
blocksize=1920, blocksize=1920,
channels=1, channels=1,
callback=audio_callback): callback=audio_callback):
time.sleep(10) run()
time.sleep(3)
while True:
if wav_frames.qsize() == 0:
break
time.sleep(1)
else: else:
wavs = mx.concat(wav_frames, axis=-1) frames = []
end_step = result.end_steps[0] while True:
wav_length = int((mimi.sample_rate * (end_step + tts_model.final_padding) / mimi.frame_rate)) try:
wav = np.array(mx.clip(wavs[0, :, :wav_length], -1, 1)) frames.append(wav_frames.get_nowait())
except queue.Empty:
break
wav = np.concat(frames, -1)
sphn.write_wav(args.out, wav, mimi.sample_rate) sphn.write_wav(args.out, wav, mimi.sample_rate)