From 07ac744609104460df1818f7355c482c9bca8b6a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=A1clav=20Volhejn?=
<8401624+vvolhejn@users.noreply.github.com>
Date: Wed, 2 Jul 2025 17:51:27 +0200
Subject: [PATCH] Add PyTorch notebook and documentation (#29)
* Add example for PyTorch implementation
* Document PyTorch and MLX examples
* Reorganize for TTS
* Remove waitlist signup CTA
---
README.md | 76 +++++++++++++++----
scripts/tts_pytorch.py | 13 +---
tts_pytorch.ipynb | 164 +++++++++++++++++++++++++++++++++++++++++
3 files changed, 228 insertions(+), 25 deletions(-)
create mode 100644 tts_pytorch.ipynb
diff --git a/README.md b/README.md
index 927582b..e36cc37 100644
--- a/README.md
+++ b/README.md
@@ -1,21 +1,17 @@
-
-
-
-
-
-
+# Delayed Streams Modeling: Kyutai STT & TTS
-
-This repo contains instructions and examples of how to run Kyutai Speech-To-Text models.
+This repo contains instructions and examples of how to run
+[Kyutai Speech-To-Text](#kyutai-speech-to-text)
+and [Kyutai Text-to-Speech](#kyutai-text-to-speech) models.
These models are powered by delayed streams modeling (DSM),
a flexible formulation for streaming, multimodal sequence-to-sequence learning.
-Text-to-speech models based on DSM coming soon!
-[Sign up here](https://docs.google.com/forms/d/15sB4zyfuwyXTii4OM74hFGkk4DlDNynJ9xywnaEzE4I/edit)
-to be notified when we open-source text-to-speech and [Unmute](https://unmute.sh).
-
## Kyutai Speech-To-Text
+
+
+
+
**More details can be found on the [project page](https://kyutai.org/next/stt).**
Kyutai STT models are optimized for real-time usage, can be batched for efficiency, and return word-level timestamps.
@@ -192,9 +188,61 @@ The MLX models can also be used in swift using the [moshi-swift
codebase](https://github.com/kyutai-labs/moshi-swift); the 1B model has been
tested to work fine on an iPhone 16 Pro.
-## Text-to-Speech
+## Kyutai Text-to-Speech
-We're in the process of open-sourcing our TTS models. Check back for updates!
+
+
+
+
+We provide different implementations of Kyutai TTS for different use cases. Here is how to choose which one to use:
+
+- PyTorch: for research and tinkering. Use it if you want to call the model from Python for experimentation.
+- Rust: for production. The Rust server provides robust, streaming access to the model over websockets; it is what we use to run [Unmute](https://unmute.sh).
+- MLX: for on-device inference on iPhone and Mac. MLX is Apple's ML framework for hardware acceleration on Apple silicon.
+
+### PyTorch implementation
+
+
+
+
+
+Check out our [Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb) or use the script:
+
+```bash
+# From stdin, plays audio immediately
+echo "Hey, how are you?" | python scripts/tts_pytorch.py - -
+
+# From text file to audio file
+python scripts/tts_pytorch.py text_to_say.txt audio_output.wav
+```
+
+This requires the [moshi package](https://pypi.org/project/moshi/), which can be installed via pip.
+If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
+and just prefix the command above with `uvx --with moshi`.
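+
+If you'd rather call the model from Python directly, the notebook added in this
+patch boils down to roughly the following sketch (condensed for illustration, not
+a drop-in replacement; it assumes a CUDA-capable GPU):
+
+```python
+import torch
+from moshi.models.loaders import CheckpointInfo
+from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel
+
+checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
+tts_model = TTSModel.from_checkpoint_info(
+    checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
+)
+
+# Prepare the script and the voice conditioning.
+entries = tts_model.prepare_script(["Hey, how are you?"], padding_between=1)
+voice_path = tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav")
+condition_attributes = tts_model.make_condition_attributes([voice_path], cfg_coef=2.0)
+
+# Generate, then decode the audio tokens to a waveform with the Mimi codec.
+result = tts_model.generate([entries], [condition_attributes])
+frames = torch.cat(result.frames, dim=-1)
+audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
+with torch.no_grad():
+    audios = tts_model.mimi.decode(audio_tokens)
+audio = audios[0].cpu().numpy()  # waveform at tts_model.mimi.sample_rate
+```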
+
+### Rust server
+
+Example coming soon.
+
+### MLX implementation
+
+[MLX](https://ml-explore.github.io/mlx/build/html/index.html) is Apple's ML framework that allows you to use
+hardware acceleration on Apple silicon.
+
+Use our example script to run Kyutai TTS on MLX.
+The script takes text from stdin or a file and can output to a file or stream the resulting audio.
+
+```bash
+# From stdin, plays audio immediately
+echo "Hey, how are you?" | python scripts/tts_mlx.py - -
+
+# From text file to audio file
+python scripts/tts_mlx.py text_to_say.txt audio_output.wav
+```
+
+This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/), which can be installed via pip.
+If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
+and just prefix the command above with `uvx --with moshi-mlx`.
## License
diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py
index 0e6e6dc..2fc7052 100644
--- a/scripts/tts_pytorch.py
+++ b/scripts/tts_pytorch.py
@@ -17,16 +17,6 @@ from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
-def audio_to_int16(audio: np.ndarray) -> np.ndarray:
- if audio.dtype == np.int16:
- return audio
- elif audio.dtype == np.float32:
- # Multiply by 32767 and not 32768 so that int16 doesn't overflow.
- return (np.clip(audio, -1, 1) * 32767).astype(np.int16)
- else:
- raise TypeError(f"Unsupported audio data type: {audio.dtype}")
-
-
def play_audio(audio: np.ndarray, sample_rate: int):
# Requires the Portaudio library which might not be available in all environments.
import sounddevice as sd
@@ -86,7 +76,8 @@ def main():
)
print("Generating audio...")
- # This doesn't do streaming generation,
+ # This doesn't do streaming generation, but the model allows it. For now, see the
+ # Rust example for streaming usage.
result = tts_model.generate([entries], [condition_attributes])
frames = torch.cat(result.frames, dim=-1)
diff --git a/tts_pytorch.ipynb b/tts_pytorch.ipynb
new file mode 100644
index 0000000..9c89892
--- /dev/null
+++ b/tts_pytorch.ipynb
@@ -0,0 +1,164 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0b7eed16",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "353b9498",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import argparse\n",
+ "import sys\n",
+ "\n",
+ "import numpy as np\n",
+ "import sphn\n",
+ "import torch\n",
+ "from moshi.models.loaders import CheckpointInfo\n",
+ "from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel\n",
+ "\n",
+ "from IPython.display import display, Audio"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "8846418a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "See https://huggingface.co/datasets/kyutai/tts-voices for available voices.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Configuration\n",
+ "text = \"Hey there! How are you? I had the craziest day today.\"\n",
+ "voice = \"expresso/ex03-ex01_happy_001_channel1_334s.wav\"\n",
+ "print(f\"See https://huggingface.co/datasets/{DEFAULT_DSM_TTS_VOICE_REPO} for available voices.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "b9f022ec",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set everything up\n",
+ "checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n",
+ "tts_model = TTSModel.from_checkpoint_info(\n",
+ " checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n",
+ ")\n",
+ "\n",
+ "# You could also generate multiple audios at once by passing a list of texts.\n",
+ "entries = tts_model.prepare_script([text], padding_between=1)\n",
+ "voice_path = tts_model.get_voice_path(voice)\n",
+ "# CFG coef goes here because the model was trained with CFG distillation,\n",
+ "# so it's not _actually_ doing CFG at inference time.\n",
+ "condition_attributes = tts_model.make_condition_attributes(\n",
+ " [voice_path], cfg_coef=2.0\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "f4f76c73",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generating audio...\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Generating audio...\")\n",
+ "\n",
+ "# This doesn't do streaming generation,\n",
+ "result = tts_model.generate([entries], [condition_attributes])\n",
+ "\n",
+ "frames = torch.cat(result.frames, dim=-1)\n",
+ "audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]\n",
+ "with torch.no_grad():\n",
+ " audios = tts_model.mimi.decode(audio_tokens)\n",
+ "\n",
+ "audio = audios[0].cpu().numpy()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "732e4b4b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "display(\n",
+ " Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n",
+ ")"
+ ]
+  }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}