Use the on_frame callback in the mlx tts example. (#34)

This commit is contained in:
Laurent Mazare 2025-07-03 09:29:04 +02:00 committed by GitHub
parent 6c1e9f12cf
commit d92e4c2695
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,7 @@
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx>=0.2.8",
# "moshi_mlx @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi_mlx&subdirectory=moshi_mlx",
# "numpy",
# "sounddevice",
# ]
@ -121,10 +121,10 @@ def main():
all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
wav_frames = queue.Queue()
def _on_audio_hook(audio_tokens):
if (audio_tokens == -1).any():
def _on_frame(frame):
if (frame == -1).any():
return
_pcm = tts_model.mimi.decode_step(audio_tokens[None, :, None])
_pcm = tts_model.mimi.decode_step(frame[:, :, None])
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
wav_frames.put_nowait(_pcm)
@ -136,7 +136,7 @@ def main():
all_attributes,
cfg_is_no_prefix=cfg_is_no_prefix,
cfg_is_no_text=cfg_is_no_text,
on_audio_hook=_on_audio_hook,
on_frame=_on_frame,
)
frames = mx.concat(result.frames, axis=-1)
total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
@ -163,6 +163,7 @@ def main():
break
time.sleep(1)
else:
run()
frames = []
while True:
try: