diff --git a/scripts/stt_from_file_with_prompt_pytorch.py b/scripts/stt_from_file_with_prompt_pytorch.py
index 63fe748..6345303 100644
--- a/scripts/stt_from_file_with_prompt_pytorch.py
+++ b/scripts/stt_from_file_with_prompt_pytorch.py
@@ -1,7 +1,6 @@
"""An example script that illustrates how one can prompt Kyutai STT models."""
import argparse
-import dataclasses
import itertools
import math
from collections import deque
diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py
index 2fc7052..515860e 100644
--- a/scripts/tts_pytorch.py
+++ b/scripts/tts_pytorch.py
@@ -11,20 +11,14 @@ import argparse
import sys
import numpy as np
+import queue
import sphn
+import time
import torch
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
-def play_audio(audio: np.ndarray, sample_rate: int):
- # Requires the Portaudio library which might not be available in all environments.
- import sounddevice as sd
-
- with sd.OutputStream(samplerate=sample_rate, blocksize=1920, channels=1):
- sd.play(audio, sample_rate)
-
-
def main():
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation"
@@ -57,6 +51,7 @@ def main():
tts_model = TTSModel.from_checkpoint_info(
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
)
+ tts_model.mimi.streaming_forever(batch_size=1)
if args.inp == "-":
if sys.stdin.isatty(): # Interactive
@@ -75,22 +70,45 @@ def main():
[voice_path], cfg_coef=2.0
)
- print("Generating audio...")
- # This doesn't do streaming generation, but the model allows it. For now, see Rust
- # example.
- result = tts_model.generate([entries], [condition_attributes])
-
- frames = torch.cat(result.frames, dim=-1)
- audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
- with torch.no_grad():
- audios = tts_model.mimi.decode(audio_tokens)
-
if args.out == "-":
- print("Playing audio...")
- play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate)
+ # Stream the audio to the speakers using sounddevice.
+ import sounddevice as sd
+
+ pcms = queue.Queue()
+
+ def _on_frame(frame):
+ if (frame != -1).all():
+ pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
+ pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
+
+ def audio_callback(outdata, _a, _b, _c):
+ try:
+ pcm_data = pcms.get(block=False)
+ outdata[:, 0] = pcm_data
+ except queue.Empty:
+ outdata[:] = 0
+
+ with sd.OutputStream(
+ samplerate=tts_model.mimi.sample_rate,
+ blocksize=1920,
+ channels=1,
+ callback=audio_callback,
+ ):
+ tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
+ time.sleep(3)
+ while True:
+ if pcms.qsize() == 0:
+ break
+ time.sleep(1)
else:
- sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
- print(f"Audio saved to {args.out}")
+ result = tts_model.generate([entries], [condition_attributes])
+ with torch.no_grad():
+ pcms = []
+ for frame in result.frames[tts_model.delay_steps :]:
+ pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
+ pcms.append(np.clip(pcm[0, 0], -1, 1))
+ pcm = np.concatenate(pcms, axis=-1)
+ sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
if __name__ == "__main__":
diff --git a/tts_pytorch.ipynb b/tts_pytorch.ipynb
index 9c89892..e62dcd1 100644
--- a/tts_pytorch.ipynb
+++ b/tts_pytorch.ipynb
@@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0b7eed16",
+ "id": "0",
"metadata": {},
"outputs": [],
"source": [
@@ -12,8 +12,8 @@
},
{
"cell_type": "code",
- "execution_count": 4,
- "id": "353b9498",
+ "execution_count": null,
+ "id": "1",
"metadata": {},
"outputs": [],
"source": [
@@ -31,18 +31,10 @@
},
{
"cell_type": "code",
- "execution_count": 13,
- "id": "8846418a",
+ "execution_count": null,
+ "id": "2",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "See https://huggingface.co/datasets/kyutai/tts-voices for available voices.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"# Configuration\n",
"text = \"Hey there! How are you? I had the craziest day today.\"\n",
@@ -52,8 +44,8 @@
},
{
"cell_type": "code",
- "execution_count": 14,
- "id": "b9f022ec",
+ "execution_count": null,
+ "id": "3",
"metadata": {},
"outputs": [],
"source": [
@@ -75,18 +67,10 @@
},
{
"cell_type": "code",
- "execution_count": 15,
- "id": "f4f76c73",
+ "execution_count": null,
+ "id": "4",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Generating audio...\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"print(\"Generating audio...\")\n",
"\n",
@@ -103,28 +87,10 @@
},
{
"cell_type": "code",
- "execution_count": 16,
- "id": "732e4b4b",
+ "execution_count": null,
+ "id": "5",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "outputs": [],
"source": [
"display(\n",
" Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n",
@@ -134,7 +100,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2dbdd275",
+ "id": "6",
"metadata": {},
"outputs": [],
"source": []