diff --git a/README.md b/README.md
index 927582b..e36cc37 100644
--- a/README.md
+++ b/README.md
@@ -1,21 +1,17 @@
-
-  Hugging Face
-
-
-  Open In Colab
-
+# Delayed Streams Modeling: Kyutai STT & TTS
-
-This repo contains instructions and examples of how to run Kyutai Speech-To-Text models.
+This repo contains instructions and examples of how to run
+[Kyutai Speech-To-Text](#kyutai-speech-to-text)
+and [Kyutai Text-To-Speech](#kyutai-text-to-speech) models.
 These models are powered by delayed streams modeling (DSM),
 a flexible formulation for streaming, multimodal sequence-to-sequence learning.
 
-Text-to-speech models based on DSM coming soon!
-[Sign up here](https://docs.google.com/forms/d/15sB4zyfuwyXTii4OM74hFGkk4DlDNynJ9xywnaEzE4I/edit)
-to be notified when we open-source text-to-speech and [Unmute](https://unmute.sh).
-
 ## Kyutai Speech-To-Text
+
+  Hugging Face
+
+
 **More details can be found on the [project page](https://kyutai.org/next/stt).**
 
 Kyutai STT models are optimized for real-time usage, can be batched for efficiency, and return word level timestamps.
@@ -192,9 +188,61 @@
 The MLX models can also be used in swift using the
 [moshi-swift codebase](https://github.com/kyutai-labs/moshi-swift), the 1b model
 has been tested to work fine on an iPhone 16 Pro.
 
-## Text-to-Speech
+## Kyutai Text-to-Speech
 
-We're in the process of open-sourcing our TTS models. Check back for updates!
+
+  Open In Colab
+
+
+We provide different implementations of Kyutai TTS for different use cases. Here is how to choose which one to use:
+
+- PyTorch: for research and tinkering. If you want to call the model from Python for research or experimentation, use our PyTorch implementation.
+- Rust: for production. If you want to serve Kyutai TTS in a production setting, use our Rust server. Our robust Rust server provides streaming access to the model over websockets. We use this server to run Unmute.
+- MLX: for on-device inference on iPhone and Mac. MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon. If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
+
+### PyTorch implementation
+
+
+  Open In Colab
+
+
+Check out our [Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb) or use the script:
+
+```bash
+# From stdin, plays audio immediately
+echo "Hey, how are you?" | python scripts/tts_pytorch.py - -
+
+# From text file to audio file
+python scripts/tts_pytorch.py text_to_say.txt audio_output.wav
+```
+
+This requires the [moshi package](https://pypi.org/project/moshi/), which can be installed via pip.
+If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
+and just prefix the command above with `uvx --with moshi`.
+
+### Rust server
+
+Example coming soon.
+
+### MLX implementation
+
+[MLX](https://ml-explore.github.io/mlx/build/html/index.html) is Apple's ML framework that allows you to use
+hardware acceleration on Apple silicon.
+
+Use our example script to run Kyutai TTS on MLX.
+The script takes text from stdin or a file and can output to a file or stream the resulting audio.
+
+```bash
+# From stdin, plays audio immediately
+echo "Hey, how are you?" | python scripts/tts_mlx.py - -
+
+# From text file to audio file
+python scripts/tts_mlx.py text_to_say.txt audio_output.wav
+```
+
+This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/), which can be installed via pip.
+If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
+and just prefix the command above with `uvx --with moshi-mlx`.
 
 ## License
diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py
index 0e6e6dc..2fc7052 100644
--- a/scripts/tts_pytorch.py
+++ b/scripts/tts_pytorch.py
@@ -17,16 +17,6 @@ from moshi.models.loaders import CheckpointInfo
 from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
 
 
-def audio_to_int16(audio: np.ndarray) -> np.ndarray:
-    if audio.dtype == np.int16:
-        return audio
-    elif audio.dtype == np.float32:
-        # Multiply by 32767 and not 32768 so that int16 doesn't overflow.
-        return (np.clip(audio, -1, 1) * 32767).astype(np.int16)
-    else:
-        raise TypeError(f"Unsupported audio data type: {audio.dtype}")
-
-
 def play_audio(audio: np.ndarray, sample_rate: int):
     # Requires the Portaudio library which might not be available in all environments.
     import sounddevice as sd
@@ -86,7 +76,8 @@ def main():
     )
 
     print("Generating audio...")
-    # This doesn't do streaming generation,
+    # This doesn't do streaming generation, but the model allows it. For now, see Rust
+    # example.
     result = tts_model.generate([entries], [condition_attributes])
 
     frames = torch.cat(result.frames, dim=-1)
diff --git a/tts_pytorch.ipynb b/tts_pytorch.ipynb
new file mode 100644
index 0000000..9c89892
--- /dev/null
+++ b/tts_pytorch.ipynb
@@ -0,0 +1,164 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0b7eed16",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install 'git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "353b9498",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import argparse\n",
+    "import sys\n",
+    "\n",
+    "import numpy as np\n",
+    "import sphn\n",
+    "import torch\n",
+    "from moshi.models.loaders import CheckpointInfo\n",
+    "from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel\n",
+    "\n",
+    "from IPython.display import display, Audio"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "8846418a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "See https://huggingface.co/datasets/kyutai/tts-voices for available voices.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Configuration\n",
+    "text = \"Hey there! How are you? I had the craziest day today.\"\n",
+    "voice = \"expresso/ex03-ex01_happy_001_channel1_334s.wav\"\n",
+    "print(f\"See https://huggingface.co/datasets/{DEFAULT_DSM_TTS_VOICE_REPO} for available voices.\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "b9f022ec",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set everything up\n",
+    "checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n",
+    "tts_model = TTSModel.from_checkpoint_info(\n",
+    "    checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n",
+    ")\n",
+    "\n",
+    "# You could also generate multiple audios at once by passing a list of texts.\n",
+    "entries = tts_model.prepare_script([text], padding_between=1)\n",
+    "voice_path = tts_model.get_voice_path(voice)\n",
+    "# CFG coef goes here because the model was trained with CFG distillation,\n",
+    "# so it's not _actually_ doing CFG at inference time.\n",
+    "condition_attributes = tts_model.make_condition_attributes(\n",
+    "    [voice_path], cfg_coef=2.0\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "f4f76c73",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Generating audio...\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"Generating audio...\")\n",
+    "\n",
+    "# This doesn't do streaming generation, but the model allows it. For now, see the Rust example.\n",
+    "result = tts_model.generate([entries], [condition_attributes])\n",
+    "\n",
+    "frames = torch.cat(result.frames, dim=-1)\n",
+    "audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]\n",
+    "with torch.no_grad():\n",
+    "    audios = tts_model.mimi.decode(audio_tokens)\n",
+    "\n",
+    "audio = audios[0].cpu().numpy()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "732e4b4b",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       "        \n",
+       "      "
+      ],
+      "text/plain": [
+       "<IPython.lib.display.Audio object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "display(\n",
+    "    Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2dbdd275",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.13.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
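
The notebook above plays the audio inline; the same calls can also be strung together as a plain Python script that writes the result to disk, closer to what `scripts/tts_pytorch.py` does. The following is a minimal sketch, not part of the patch: it reuses the exact API shown in `tts_pytorch.ipynb`, assumes a CUDA device, and assumes `sphn.write_wav(path, data, sample_rate)` for saving (the notebook imports `sphn` but never calls it); the output filename simply mirrors the README example.

```python
# Minimal sketch: same PyTorch TTS flow as tts_pytorch.ipynb, but saving a WAV file
# instead of playing audio inline. Assumes a CUDA device and that sphn.write_wav exists.
import sphn
import torch
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel

text = "Hey there! How are you? I had the craziest day today."
voice = "expresso/ex03-ex01_happy_001_channel1_334s.wav"  # see the kyutai/tts-voices dataset

# Load the TTS model from the Hugging Face repo.
checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
tts_model = TTSModel.from_checkpoint_info(
    checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
)

# Prepare the script and the voice conditioning (cfg_coef as in the notebook).
entries = tts_model.prepare_script([text], padding_between=1)
voice_path = tts_model.get_voice_path(voice)
condition_attributes = tts_model.make_condition_attributes([voice_path], cfg_coef=2.0)

# Generate audio tokens and decode them with Mimi.
result = tts_model.generate([entries], [condition_attributes])
frames = torch.cat(result.frames, dim=-1)
audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
with torch.no_grad():
    audios = tts_model.mimi.decode(audio_tokens)

# Write the first (and only) generated waveform to a file.
audio = audios[0].cpu().numpy()
sphn.write_wav("audio_output.wav", audio, tts_model.mimi.sample_rate)
```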