diff --git a/scripts/tts_mlx.py b/scripts/tts_mlx.py index 70dca65..9d89295 100644 --- a/scripts/tts_mlx.py +++ b/scripts/tts_mlx.py @@ -2,7 +2,7 @@ # requires-python = ">=3.12" # dependencies = [ # "huggingface_hub", -# "moshi_mlx>=0.2.8", +# "moshi_mlx @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi_mlx&subdirectory=moshi_mlx", # "numpy", # "sounddevice", # ] @@ -121,10 +121,10 @@ def main(): all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)] wav_frames = queue.Queue() - def _on_audio_hook(audio_tokens): - if (audio_tokens == -1).any(): + def _on_frame(frame): + if (frame == -1).any(): return - _pcm = tts_model.mimi.decode_step(audio_tokens[None, :, None]) + _pcm = tts_model.mimi.decode_step(frame[:, :, None]) _pcm = np.array(mx.clip(_pcm[0, 0], -1, 1)) wav_frames.put_nowait(_pcm) @@ -136,7 +136,7 @@ def main(): all_attributes, cfg_is_no_prefix=cfg_is_no_prefix, cfg_is_no_text=cfg_is_no_text, - on_audio_hook=_on_audio_hook, + on_frame=_on_frame, ) frames = mx.concat(result.frames, axis=-1) total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate @@ -163,6 +163,7 @@ def main(): break time.sleep(1) else: + run() frames = [] while True: try: