From 5f8e924176f479c560dbb9779849339950888906 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 3 Jul 2025 11:05:06 +0200 Subject: [PATCH] Streaming output for the pytorch tts example. (#33) * Streaming output for the pytorch tts example. * Run the pre-commit hooks. --- scripts/stt_from_file_with_prompt_pytorch.py | 1 - scripts/tts_pytorch.py | 62 ++++++++++++------- tts_pytorch.ipynb | 64 +++++--------------- 3 files changed, 55 insertions(+), 72 deletions(-) diff --git a/scripts/stt_from_file_with_prompt_pytorch.py b/scripts/stt_from_file_with_prompt_pytorch.py index 63fe748..6345303 100644 --- a/scripts/stt_from_file_with_prompt_pytorch.py +++ b/scripts/stt_from_file_with_prompt_pytorch.py @@ -1,7 +1,6 @@ """An example script that illustrates how one can prompt Kyutai STT models.""" import argparse -import dataclasses import itertools import math from collections import deque diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py index 2fc7052..515860e 100644 --- a/scripts/tts_pytorch.py +++ b/scripts/tts_pytorch.py @@ -11,20 +11,14 @@ import argparse import sys import numpy as np +import queue import sphn +import time import torch from moshi.models.loaders import CheckpointInfo from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel -def play_audio(audio: np.ndarray, sample_rate: int): - # Requires the Portaudio library which might not be available in all environments. - import sounddevice as sd - - with sd.OutputStream(samplerate=sample_rate, blocksize=1920, channels=1): - sd.play(audio, sample_rate) - - def main(): parser = argparse.ArgumentParser( description="Run Kyutai TTS using the PyTorch implementation" @@ -57,6 +51,7 @@ def main(): tts_model = TTSModel.from_checkpoint_info( checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half ) + tts_model.mimi.streaming_forever(batch_size=1) if args.inp == "-": if sys.stdin.isatty(): # Interactive @@ -75,22 +70,45 @@ def main(): [voice_path], cfg_coef=2.0 ) - print("Generating audio...") - # This doesn't do streaming generation, but the model allows it. For now, see Rust - # example. - result = tts_model.generate([entries], [condition_attributes]) - - frames = torch.cat(result.frames, dim=-1) - audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :] - with torch.no_grad(): - audios = tts_model.mimi.decode(audio_tokens) - if args.out == "-": - print("Playing audio...") - play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate) + # Stream the audio to the speakers using sounddevice. + import sounddevice as sd + + pcms = queue.Queue() + + def _on_frame(frame): + if (frame != -1).all(): + pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy() + pcms.put_nowait(np.clip(pcm[0, 0], -1, 1)) + + def audio_callback(outdata, _a, _b, _c): + try: + pcm_data = pcms.get(block=False) + outdata[:, 0] = pcm_data + except queue.Empty: + outdata[:] = 0 + + with sd.OutputStream( + samplerate=tts_model.mimi.sample_rate, + blocksize=1920, + channels=1, + callback=audio_callback, + ): + tts_model.generate([entries], [condition_attributes], on_frame=_on_frame) + time.sleep(3) + while True: + if pcms.qsize() == 0: + break + time.sleep(1) else: - sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate) - print(f"Audio saved to {args.out}") + result = tts_model.generate([entries], [condition_attributes]) + with torch.no_grad(): + pcms = [] + for frame in result.frames[tts_model.delay_steps :]: + pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy() + pcms.append(np.clip(pcm[0, 0], -1, 1)) + pcm = np.concatenate(pcms, axis=-1) + sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate) if __name__ == "__main__": diff --git a/tts_pytorch.ipynb b/tts_pytorch.ipynb index 9c89892..e62dcd1 100644 --- a/tts_pytorch.ipynb +++ b/tts_pytorch.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0b7eed16", + "id": "0", "metadata": {}, "outputs": [], "source": [ @@ -12,8 +12,8 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "353b9498", + "execution_count": null, + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -31,18 +31,10 @@ }, { "cell_type": "code", - "execution_count": 13, - "id": "8846418a", + "execution_count": null, + "id": "2", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "See https://huggingface.co/datasets/kyutai/tts-voices for available voices.\n" - ] - } - ], + "outputs": [], "source": [ "# Configuration\n", "text = \"Hey there! How are you? I had the craziest day today.\"\n", @@ -52,8 +44,8 @@ }, { "cell_type": "code", - "execution_count": 14, - "id": "b9f022ec", + "execution_count": null, + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -75,18 +67,10 @@ }, { "cell_type": "code", - "execution_count": 15, - "id": "f4f76c73", + "execution_count": null, + "id": "4", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Generating audio...\n" - ] - } - ], + "outputs": [], "source": [ "print(\"Generating audio...\")\n", "\n", @@ -103,28 +87,10 @@ }, { "cell_type": "code", - "execution_count": 16, - "id": "732e4b4b", + "execution_count": null, + "id": "5", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "display(\n", " Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n", @@ -134,7 +100,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2dbdd275", + "id": "6", "metadata": {}, "outputs": [], "source": []