Compare commits

..

43 Commits

Author SHA1 Message Date
96042c9983 install.sh added
2025-08-12 18:03:43 +00:00
7e92c6f28a app/scripts/tts_runner.py added
2025-08-12 17:59:20 +00:00
8e51c6eab9 app/dependency_check.py updated
2025-08-12 17:58:17 +00:00
34e58b8055 dependency_check.py added
2025-08-12 17:57:33 +00:00
f23d1be027 app/api_server.py added
2025-08-12 17:56:32 +00:00
Laurent Mazare
cf97f8d863
Workaround for the mlx kv-cache bug. (#108) 2025-08-04 16:37:00 +02:00
Laurent Mazare
09468c239a
Print the duration of the audio generated so far. (#107) 2025-08-04 09:24:31 +02:00
Laurent Mazare
07729ed47e
Use the proper repos when vad is on. (#103) 2025-08-01 15:55:49 +02:00
Laurent Mazare
af2283de3f
Use a streaming input in the rust example. (#102)
* Use a streaming input in the rust example.

* Formatting.

* Another formatting tweak.
2025-07-31 17:41:57 +02:00
Laurent Mazare
7dc926d50c
Allow for using local voices in the pytorch examples. (#100) 2025-07-31 12:48:05 +02:00
Laurent Mazare
ab8e8c59b7
Bump the version numbers. (#91) 2025-07-19 15:57:53 +02:00
laurent
5f17114618 More faq. 2025-07-18 08:43:10 +02:00
laurent
405a82ba3f FAQ tweaks. 2025-07-18 08:37:51 +02:00
Laurent Mazare
3b584b100c
Sketch a FAQ and add some issue templates. (#88) 2025-07-18 08:31:31 +02:00
Laurent Mazare
a98eb94ade
Add a streaming example for the mlx tts. (#85)
* Add a streaming example for the mlx tts.

* Fix the CI.

* Formatting fix.

* Yet another CI fix.
2025-07-16 22:35:44 +02:00
Laurent Mazare
a2f031deb5
Fix the pytorch tts streaming example. (#84)
* Fix the pytorch tts streaming example.

* Edit the readme too.
2025-07-16 21:07:02 +02:00
Laurent Mazare
66a33c989f
Add a TTS streaming example. (#83)
* Add a TTS streaming example.

* Get the streaming example to work.
2025-07-16 20:54:13 +02:00
laurent
89a2ced839 Bump the moshi version. 2025-07-16 16:02:28 +02:00
Laurent
baf0c75bba VAD support in the mlx-stt example that uses the microphone. 2025-07-08 16:08:32 +02:00
Laurent Mazare
952319de90
Add a MLX STT example that uses VAD. (#70)
* Add a MLX STT example that uses VAD.

* VAD support.

* More MLX VAD example.

* Use the latest moshi-mlx.
2025-07-08 16:04:50 +02:00
laurent
6d3bb6b1f1 Avoid the config override for the extra-heads. 2025-07-08 15:39:17 +02:00
Laurent Mazare
12dbe36b0b
Add some VAD to the pytorch speech-to-text example. (#68) 2025-07-08 11:30:34 +02:00
Václav Volhejn
cafac63222
Run pre-commit correctly in CI (#66)
* fix and break

* Remove intentional error
2025-07-08 10:11:52 +02:00
Václav Volhejn
7336d7a3da
Fix instructions on how to install the Rust server (#65)
* Fix instructions on Rust server installation

* Plug Unmute
2025-07-07 18:00:23 +02:00
Laurent Mazare
70500c620e
Add a device argument to the tts pytorch script. (#62) 2025-07-07 08:36:47 +02:00
Chenghao Mou
f8e97aa4f3
fix minor issues with readme commands (#55) 2025-07-07 08:18:05 +02:00
laurent
91a4d120cb Use moshi 0.2.8. 2025-07-07 08:12:16 +02:00
Laurent
bfc200f6ee Use bfloat16 rather than half by default. 2025-07-05 23:02:58 +02:00
laurent
f9739881e6 Typo. 2025-07-03 19:02:42 +02:00
Alexandre Défossez
99599fa408
Update README.md 2025-07-03 16:15:01 +02:00
Pierre-Hugues HUSSON
3a4165a84f
Fix stt_from_file_pytorch (#39)
1. argparse declares in_file, but code reads file
2. text_tokens.numpy().tolist() is a list of list of list of int
instead of the supported list of list of int.
this is a debugging print just drop it

Co-authored-by: Pierre-Hugues Husson <phhusson@freebox.fr>
2025-07-03 15:26:34 +02:00
Alexandre Défossez
e9bac066ea
Update README.md 2025-07-03 15:09:41 +02:00
Alexandre Défossez
eae5e17975
Some updates to the colab and script (#38)
* changing streaming to be robust to repeated generation

* some changes

* plop

* plop

* plop

* plop
2025-07-03 15:06:37 +02:00
Václav Volhejn
c1d248abba
Fix text tokenizer path (#36) 2025-07-03 14:27:06 +02:00
Václav Volhejn
c6f262346f
Don't install moshi from Git (#37)
* Don't install moshi from Git

* Remove commented-out invalid message send in websocket_client
2025-07-03 13:37:38 +02:00
laurent
3573ee90af Oops. 2025-07-03 13:08:00 +02:00
laurent
25574aa104 Fixes for the notebook. 2025-07-03 13:05:00 +02:00
laurent
1cd9529f65 Json fix. 2025-07-03 12:57:22 +02:00
laurent
0ee2354176 Chunk decoding in the pth notebook. 2025-07-03 12:56:00 +02:00
laurent
dc8bffabe0 Remove the dataset bit. 2025-07-03 12:48:04 +02:00
Laurent Mazare
5f8e924176
Streaming output for the pytorch tts example. (#33)
* Streaming output for the pytorch tts example.

* Run the pre-commit hooks.
2025-07-03 11:05:06 +02:00
Laurent Mazare
d3bed09f9a
Pin the moshi_mlx version. (#35) 2025-07-03 09:53:53 +02:00
Václav Volhejn
ef52b8ef0f
Add Rust server usage example (#32)
* Run Ruff on tts_mlx.py

* Add tts_rust_server.py example

* Remove unused HF repo arguments and reset audio output data in TTS server script
2025-07-03 09:47:50 +02:00
22 changed files with 1823 additions and 131 deletions

83
.github/ISSUE_TEMPLATE/bug.yml vendored Normal file

@ -0,0 +1,83 @@
name: Bug Report
description: You found a bug.
labels: ["bug", "triage"]
body:
- type: markdown
attributes:
value: |
Please first check the [FAQ](https://github.com/kyutai-labs/delayed-streams-modeling/blob/main/FAQ.md).
- type: dropdown
id: backend
attributes:
label: Backend impacted
description: Which backend is concerned with your bug report?
options:
- The PyTorch implementation
- The MLX implementation
- The Rust implementation
- Other / All
default: 0
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating system
description: What is your operating system?
options:
- Linux
- Mac OS X
- Windows (unsupported)
default: 0
validations:
required: true
- type: dropdown
id: hardware
attributes:
label: Hardware
description: What hardware are you using?
options:
- CPU
- GPU with CUDA
- Metal with MLX
default: 0
validations:
required: true
- type: textarea
id: description
attributes:
label: Description
description: Provide a detailed description of your bug.
placeholder:
value:
validations:
required: true
- type: textarea
id: more_info
attributes:
label: Extra information
description: Please provide any other relevant information, such as log extracts, code etc.
placeholder:
value:
validations:
required: true
- type: textarea
id: env
attributes:
label: Environment
description: Please provide any other relevant information, such as log extracts, code etc.
placeholder:
value: |
Fill in the following information on your system.
- Operating system version:
If the backend impacted is PyTorch:
- Python version:
- PyTorch version:
- CUDA version (run `python -c 'import torch; print(torch.version.cuda)'`):
- GPU model and memory:
If the backend is MLX:
- Mac model:
validations:
required: true

40
.github/ISSUE_TEMPLATE/question.yml vendored Normal file

@ -0,0 +1,40 @@
name: Question
description: You have a question about the codebase, the paper, or the implementation.
labels: ["question", "triage"]
body:
- type: markdown
attributes:
value: |
Please first check the [FAQ](https://github.com/kyutai-labs/delayed-streams-modeling/blob/main/FAQ.md).
- type: checkboxes
id: terms
attributes:
label: Due diligence
description: Have you searched the existing issues / FAQ / Google / asked ChatGPT?
options:
- label: I have done my due diligence in trying to find the answer myself.
required: true
- type: dropdown
id: backend
attributes:
label: Topic
description: What is your question about?
options:
- The paper
- The PyTorch implementation
- The MLX implementation
- The Rust implementation
- Other / All
default: 0
validations:
required: true
- type: textarea
id: question
attributes:
label: Question
description: What is your question?
placeholder: Your question. Please make sure this is directly related to our codebase. We will not provide support for installing PyTorch, CUDA, Rust etc.
value:
validations:
required: true

View File

@ -19,7 +19,7 @@ runs:
. env/bin/activate
python -m pip install --upgrade pip
pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu
- pip install moshi==0.2.6
+ pip install moshi==0.2.7
pip install pre-commit
- name: Setup env
shell: bash

View File

@ -13,5 +13,5 @@ jobs:
- uses: actions/checkout@v2
- uses: ./.github/actions/moshi_build
- run: |
-   . env/bin/activate
+   source env/bin/activate
-   bash .git/hooks/pre-commit
+   pre-commit run --all-files

1
.gitignore vendored

@ -192,3 +192,4 @@ cython_debug/
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
+ out*.wav

56
FAQ.md Normal file

@ -0,0 +1,56 @@
# FAQ
Here is the answer to a number of frequently asked questions.
### Torch compilation issues
With some PyTorch/triton versions, one might encounter compilation errors
like the following:
```
Traceback (most recent call last):
...
File "site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1153, in make_launcher
"launch_enter_hook": binary.__class__.launch_enter_hook,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._inductor.exc.InductorError: AttributeError: type object 'CompiledKernel' has no attribute 'launch_enter_hook'
```
If that's the case, you can disable torch compilation by setting the following
environment variable.
```bash
export NO_TORCH_COMPILE=1
```
### Issues installing the sentencepiece dependency
On some linux distributions (arch) or on macos, the local version of cmake can
be too recent for the sentencepiece dependency.
```
CMake Error at CMakeLists.txt:15 (cmake_minimum_required):
Compatibility with CMake < 3.5 has been removed from CMake.
```
You can either downgrade your cmake version, e.g. 3.31.0 on arch works or try
setting `CMAKE_POLICY_VERSION_MINIMUM=3.5`.
If you run into some errors when compiling the sentencepiece rust bindings,
these could also be due to gcc being too recent, e.g. gcc 15. You can get
around this by using gcc-13, e.g. by setting the following after installing
the proper gcc packages.
```bash
export CMAKE_C_COMPILER=/usr/bin/gcc-13
export CMAKE_CXX_COMPILER=/usr/bin/g++-13
CC=gcc-13 CXX=g++-13 cargo build --release
```
Alternatively you can set `CXXFLAGS="-include cstdint"`, see this
[issue](https://github.com/google/sentencepiece/issues/1108).
### Will you release training code?
Some finetuning code can be found in the [kyutai-labs/moshi-finetune repo](https://github.com/kyutai-labs/moshi-finetune).
This code has not been adapted to the Speech-To-Text and Text-To-Speech models
yet, but it should be a good starting point.

View File

@ -5,6 +5,10 @@ This repo contains instructions and examples of how to run
and [Kyutai Text-To-Speech](#kyutai-text-to-speech) models.
These models are powered by delayed streams modeling (DSM),
a flexible formulation for streaming, multimodal sequence-to-sequence learning.
+ See also [Unmute](https://github.com/kyutai-labs/unmute), an voice AI system built using Kyutai STT and Kyutai TTS.
+ But wait, what is "Delayed Streams Modeling"? It is a technique for solving many streaming X-to-Y tasks (with X, Y in `{speech, text}`)
+ that formalize the approach we had with Moshi and Hibiki. A pre-print paper is coming soon!
## Kyutai Speech-To-Text
@ -77,7 +81,7 @@ Additionally, we provide two scripts that highlight different usage scenarios. T
uv run \
scripts/stt_from_file_pytorch.py \
--hf-repo kyutai/stt-2.6b-en \
- --file audio/bria.mp3
+ audio/bria.mp3
```
The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
@ -113,7 +117,7 @@ However, please bear in mind that is an experimental feature and its behavior is
</a>
The Rust implementation provides a server that can process multiple streaming
- queries in parallel. Dependening on the amount of memory on your GPU, you may
+ queries in parallel. Depending on the amount of memory on your GPU, you may
have to adjust the batch size from the config file. For a L40S GPU, a batch size
of 64 works well and requests can be processed at 3x real-time speed.
@ -161,7 +165,7 @@ A standalone Rust example script is provided in the `stt-rs` directory in this r
This can be used as follows:
```bash
cd stt-rs
- cargo run --features cuda -r -- audio/bria.mp3
+ cargo run --features cuda -r -- ../audio/bria.mp3
```
You can get the timestamps by adding the `--timestamps` flag, and see the output
of the semantic VAD by adding the `--vad` flag.
@ -204,10 +208,12 @@ tested to work fine on an iPhone 16 Pro.
<a href="https://huggingface.co/collections/kyutai/text-to-speech-6866192e7e004ed04fd39e29" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiTTS-blue" style="display: inline-block; vertical-align: middle;"/>
</a>
- <a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb">
+ <a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
+ **More details can be found on the [project page](https://kyutai.org/next/tts).**
We provide different implementations of Kyutai TTS for different use cases. Here is how to choose which one to use:
- PyTorch: for research and tinkering. If you want to call the model from Python for research or experimentation, use our PyTorch implementation.
@ -231,6 +237,14 @@ echo "Hey, how are you?" | python scripts/tts_pytorch.py - -
python scripts/tts_pytorch.py text_to_say.txt audio_output.wav
```
+ The `tts_pytorch.py` script waits for all the text to be available before
+ starting the audio generation. A fully streaming implementation is available in
+ the `tts_pytorch_streaming.py` script, which can be used as follows:
+ ```bash
+ echo "Hey, how are you?" | python scripts/tts_pytorch_streaming.py audio_output.wav
+ ```
This requires the [moshi package](https://pypi.org/project/moshi/), which can be installed via pip.
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
and just prefix the command above with `uvx --with moshi`.
@ -239,7 +253,31 @@ and just prefix the command above with `uvx --with moshi`.
<details>
<summary>Rust server</summary>
- Example coming soon.
+ The Rust implementation provides a server that can process multiple streaming
+ queries in parallel.
+ Installing the Rust server is a bit tricky because it uses our Python implementation under the hood,
+ which also requires installing the Python dependencies.
+ Use the [start_tts.sh](https://github.com/kyutai-labs/unmute/blob/main/dockerless/start_tts.sh) script to properly install the Rust server.
+ If you already installed the `moshi-server` crate before and it's not working, you might need to force a reinstall by running `cargo uninstall moshi-server` first.
+ Feel free to open an issue if the installation is still broken.
+ Once installed, the server can be started via the following command using the config file
+ from this repository.
+ ```bash
+ moshi-server worker --config configs/config-tts.toml
+ ```
+ Once the server has started you can connect to it using our script as follows:
+ ```bash
+ # From stdin, plays audio immediately
+ echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
+ # From text file to audio file
+ python scripts/tts_rust_server.py text_to_say.txt audio_output.wav
+ ```
</details>
<details>
@ -267,6 +305,10 @@ If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the install
and just prefix the command above with `uvx --with moshi-mlx`.
</details>
+ ## FAQ
+ Checkout the [Frequently Asked Questions](FAQ.md) section before opening an issue.
## License
The present code is provided under the MIT license for the Python parts, and Apache license for the Rust backend.

298
app/api_server.py Normal file

@ -0,0 +1,298 @@
#!/usr/bin/env python3
"""
OpenAI-Compatible Kyutai TTS API Server with Model Caching
Improved version that loads the model once and keeps it in memory
"""
import os
import io
import time
import asyncio
import subprocess
from pathlib import Path
from typing import Optional, Literal
import logging
import torch
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global model variables - loaded once at startup
tts_model = None
device = None
sample_rate = None
class SpeechRequest(BaseModel):
model: Literal["tts-1", "tts-1-hd"] = Field("tts-1", description="TTS model to use")
input: str = Field(..., min_length=1, max_length=4096, description="Text to generate audio for")
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field("alloy", description="Voice to use")
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field("mp3", description="Audio format")
speed: Optional[float] = Field(1.0, ge=0.25, le=4.0, description="Speed of generated audio")
app = FastAPI(
title="OpenAI-Compatible TTS API (Cached)",
description="OpenAI Audio Speech API compatible endpoint using Kyutai TTS with model caching",
version="2.0.0"
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
OUTPUT_DIR = Path("/app/api_output")
OUTPUT_DIR.mkdir(exist_ok=True)
def load_tts_model():
"""Load TTS model once at startup and keep in memory"""
global tts_model, device, sample_rate
if tts_model is not None:
logger.info("TTS model already loaded")
return
try:
logger.info("🚀 Loading Kyutai TTS model (one-time initialization)...")
# Import Kyutai TTS modules
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Load the TTS model
checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
tts_model = TTSModel.from_checkpoint_info(
checkpoint_info,
n_q=32,
temp=0.6,
device=device
)
# Get sample rate
sample_rate = tts_model.mimi.sample_rate
logger.info(f"✅ TTS model loaded successfully!")
logger.info(f" Model: {DEFAULT_DSM_TTS_REPO}")
logger.info(f" Device: {device}")
logger.info(f" Sample Rate: {sample_rate}")
except Exception as e:
logger.error(f"❌ Failed to load TTS model: {e}")
raise
def generate_audio_fast(text: str, voice: str = "alloy", speed: float = 1.0) -> bytes:
"""Generate audio using cached TTS model"""
global tts_model, device, sample_rate
if tts_model is None:
raise HTTPException(status_code=500, detail="TTS model not loaded")
try:
logger.info(f"🎵 Generating audio for: '{text[:50]}{'...' if len(text) > 50 else ''}'")
# Prepare the script (text input)
entries = tts_model.prepare_script([text], padding_between=1)
# Voice mapping for OpenAI compatibility
voice_mapping = {
"alloy": "expresso/ex03-ex01_happy_001_channel1_334s.wav",
"echo": "expresso/ex04-ex01_happy_001_channel1_334s.wav",
"fable": "expresso/ex05-ex01_happy_001_channel1_334s.wav",
"onyx": "expresso/ex06-ex01_happy_001_channel1_334s.wav",
"nova": "expresso/ex07-ex01_happy_001_channel1_334s.wav",
"shimmer": "expresso/ex08-ex01_happy_001_channel1_334s.wav"
}
selected_voice = voice_mapping.get(voice, voice_mapping["alloy"])
try:
voice_path = tts_model.get_voice_path(selected_voice)
except:
# Fallback to default if voice not found
voice_path = tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav")
# Prepare condition attributes
condition_attributes = tts_model.make_condition_attributes(
[voice_path], cfg_coef=2.0
)
# Generate audio
pcms = []
def on_frame(frame):
if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.append(torch.clamp(torch.from_numpy(pcm[0, 0]), -1, 1).numpy())
all_entries = [entries]
all_condition_attributes = [condition_attributes]
with tts_model.mimi.streaming(len(all_entries)):
result = tts_model.generate(all_entries, all_condition_attributes, on_frame=on_frame)
# Concatenate all audio frames
if pcms:
import numpy as np
audio = np.concatenate(pcms, axis=-1)
# Apply speed adjustment if needed
if speed != 1.0:
# Simple speed adjustment by resampling
from scipy import signal
audio_length = len(audio)
new_length = int(audio_length / speed)
audio = signal.resample(audio, new_length)
# Convert to bytes
audio_bytes = io.BytesIO()
sf.write(audio_bytes, audio, samplerate=sample_rate, format='WAV')
audio_bytes.seek(0)
logger.info(f"✅ Audio generated successfully ({len(audio)/sample_rate:.2f}s)")
return audio_bytes.read()
else:
raise Exception("No audio frames generated")
except Exception as e:
logger.error(f"❌ TTS generation error: {e}")
raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")
def convert_audio_format(audio_wav_bytes: bytes, output_format: str) -> bytes:
"""Convert WAV audio to requested format using ffmpeg"""
try:
if output_format == "wav":
return audio_wav_bytes
# Use ffmpeg to convert
cmd = ["ffmpeg", "-f", "wav", "-i", "pipe:0", "-f", output_format, "pipe:1"]
result = subprocess.run(
cmd,
input=audio_wav_bytes,
capture_output=True,
check=True
)
return result.stdout
except subprocess.CalledProcessError as e:
logger.error(f"Audio conversion failed: {e}")
raise HTTPException(status_code=500, detail=f"Audio conversion failed: {e}")
@app.post("/v1/audio/speech")
async def create_speech(request: SpeechRequest):
"""
OpenAI-compatible audio speech endpoint
Uses cached TTS model for fast generation
"""
try:
start_time = time.time()
# Generate audio with cached model
audio_wav_bytes = generate_audio_fast(
text=request.input,
voice=request.voice,
speed=request.speed
)
# Convert to requested format
audio_data = convert_audio_format(audio_wav_bytes, request.response_format)
generation_time = time.time() - start_time
logger.info(f"⚡ Total generation time: {generation_time:.2f}s")
# Set appropriate content type
content_types = {
"mp3": "audio/mpeg",
"opus": "audio/opus",
"aac": "audio/aac",
"flac": "audio/flac",
"wav": "audio/wav",
"pcm": "audio/pcm"
}
return Response(
content=audio_data,
media_type=content_types.get(request.response_format, "audio/wav"),
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Generation-Time": str(generation_time)
}
)
except Exception as e:
logger.error(f"Speech generation failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/v1/models")
async def list_models():
"""List available models (OpenAI-compatible)"""
return {
"object": "list",
"data": [
{
"id": "tts-1",
"object": "model",
"created": 1677610602,
"owned_by": "kyutai",
"permission": [],
"root": "tts-1",
"parent": None
},
{
"id": "tts-1-hd",
"object": "model",
"created": 1677610602,
"owned_by": "kyutai",
"permission": [],
"root": "tts-1-hd",
"parent": None
}
]
}
@app.get("/health")
async def health_check():
"""Health check endpoint with model status"""
model_loaded = tts_model is not None
return {
"status": "healthy" if model_loaded else "loading",
"model_loaded": model_loaded,
"cuda_available": torch.cuda.is_available(),
"device": str(device) if device else None,
"service": "kyutai-tts-openai-compatible-cached"
}
@app.get("/reload-model")
async def reload_model():
"""Reload the TTS model (admin endpoint)"""
global tts_model
try:
tts_model = None
load_tts_model()
return {"status": "success", "message": "Model reloaded successfully"}
except Exception as e:
return {"status": "error", "message": str(e)}
@app.on_event("startup")
async def startup_event():
"""Load model on startup"""
logger.info("🚀 Starting TTS API server with model caching...")
load_tts_model()
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
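The server above keeps the Kyutai TTS model in memory and mimics the OpenAI audio API. As a quick smoke test (a sketch, assuming the server is running locally on port 8000 as in the `uvicorn.run` call above), the `/health` and `/v1/audio/speech` routes defined in this file can be exercised like so:
```bash
# Check that the cached model is loaded.
curl -s http://localhost:8000/health

# Request speech in WAV format via the OpenAI-compatible endpoint.
curl -s -X POST http://localhost:8000/v1/audio/speech \
  -H "Content-Type: application/json" \
  -d '{"model": "tts-1", "input": "Hey, how are you?", "voice": "alloy", "response_format": "wav"}' \
  -o speech.wav
```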

67
app/dependency_check.py Normal file

@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""
Check if all Kyutai TTS dependencies are properly installed
"""
import sys
def check_dependencies():
print("🔍 Checking Kyutai TTS Dependencies")
print("=" * 40)
dependencies = [
"torch",
"numpy",
"einops",
"transformers",
"accelerate",
"soundfile",
"librosa",
"huggingface_hub",
"moshi",
"sphn"
]
missing = []
installed = []
for dep in dependencies:
try:
__import__(dep)
installed.append(dep)
print(f"{dep}")
except ImportError as e:
missing.append((dep, str(e)))
print(f"{dep}: {e}")
print(f"\n📊 Summary:")
print(f"✓ Installed: {len(installed)}")
print(f"✗ Missing: {len(missing)}")
if missing:
print(f"\n🔧 To fix missing dependencies:")
for dep, error in missing:
print(f"pip install {dep}")
print(f"\n🧪 Testing Kyutai TTS imports:")
try:
from moshi.models.loaders import CheckpointInfo
print("✓ CheckpointInfo import successful")
except Exception as e:
print(f"✗ CheckpointInfo import failed: {e}")
try:
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
print("✓ TTSModel imports successful")
except Exception as e:
print(f"✗ TTSModel imports failed: {e}")
return len(missing) == 0
if __name__ == "__main__":
success = check_dependencies()
if success:
print("\n🎉 All dependencies are installed correctly!")
else:
print("\n❌ Some dependencies are missing. Please install them first.")
sys.exit(1)
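A hedged usage note: the checker is a standalone script, so it can simply be run inside whatever environment the TTS stack was installed into (the venv path below matches the install.sh added in this comparison and is otherwise an assumption):
```bash
# Activate the environment created by install.sh, then run the check.
source ~/venv-tts-kyutai/bin/activate
python app/dependency_check.py
```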

59
app/scripts/tts_runner.py Normal file

@ -0,0 +1,59 @@
#!/usr/bin/env python3
"""
Kyutai TTS PyTorch Runner
Dockerized implementation for text-to-speech generation
"""
import sys
import os
import argparse
import torch
from pathlib import Path
def main():
parser = argparse.ArgumentParser(description='Kyutai TTS PyTorch Runner')
parser.add_argument('input_file', help='Input text file or "-" for stdin')
parser.add_argument('output_file', help='Output audio file')
parser.add_argument('--model', default='kyutai/tts-1.6b-en_fr', help='TTS model to use')
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use')
args = parser.parse_args()
print(f"Using device: {args.device}")
print(f"CUDA available: {torch.cuda.is_available()}")
# Handle stdin input
if args.input_file == '-':
# Read from stdin and create temporary file
text = sys.stdin.read().strip()
temp_file = '/tmp/temp_input.txt'
with open(temp_file, 'w') as f:
f.write(text)
input_file = temp_file
else:
input_file = args.input_file
# Check if the original TTS script exists
tts_script = Path('/app/scripts/tts_pytorch.py')
if tts_script.exists():
print("Using original TTS script from Kyutai repository")
import subprocess
cmd = ['python', str(tts_script), input_file, args.output_file]
subprocess.run(cmd, check=True)
else:
print("Using moshi package for TTS generation")
import subprocess
cmd = [
'python', '-m', 'moshi.run_inference',
'--hf-repo', args.model,
input_file,
args.output_file
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"Error: {result.stderr}")
sys.exit(1)
print(f"Audio generated: {args.output_file}")
if __name__ == '__main__':
main()
EOF
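A usage sketch for the runner above; the /app paths mirror the Docker-style layout the script assumes and are not required elsewhere:
```bash
# Read text from stdin ("-") and write a wav file.
echo "Hey, how are you?" | python /app/scripts/tts_runner.py - /app/output/hello.wav

# Read from a text file; --device is accepted but only reported by the script.
python /app/scripts/tts_runner.py input.txt output.wav --device cpu
```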

20
configs/config-tts.toml Normal file

@ -0,0 +1,20 @@
static_dir = "./static/"
log_dir = "$HOME/tmp/tts-logs"
instance_name = "tts"
authorized_ids = ["public_token"]
[modules.tts_py]
type = "Py"
path = "/api/tts_streaming"
text_tokenizer_file = "hf://kyutai/tts-1.6b-en_fr/tokenizer_spm_8k_en_fr_audio.model"
batch_size = 8 # Adjust to your GPU memory capacity
text_bos_token = 1
[modules.tts_py.py]
log_folder = "$HOME/tmp/moshi-server-logs"
voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
default_voice = "unmute-prod-website/default_voice.wav"
cfg_coef = 2.0
cfg_is_no_text = true
padding_between = 1
n_q = 24
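This is the config referenced by the README changes earlier in this comparison; under those instructions the Rust server would be started and queried roughly as follows (assuming `moshi-server` and the Python client dependencies are already installed):
```bash
# Start the Rust TTS server with this config, then stream text to it
# using the client script from this repository.
moshi-server worker --config configs/config-tts.toml
echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
```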

78
install.sh Normal file

@ -0,0 +1,78 @@
# Set environment variables
export DEBIAN_FRONTEND=noninteractive
export PYTHONUNBUFFERED=1
export CUDA_VISIBLE_DEVICES=0
# Install system dependencies
apt-get update && apt-get install -y \
wget \
curl \
git \
build-essential \
libsndfile1 \
ffmpeg \
sox \
alsa-utils \
pulseaudio \
&& rm -rf /var/lib/apt/lists/*
# Install Python dependencies first (for better caching)
pip install --no-cache-dir --upgrade pip
# Create virtual environment
apt install python3.12-venv python3.12-dev
python3.12 -m venv ~/venv-tts-kyutai
source ~/venv-tts-kyutai/bin/activate
# Install Python dependencies first (for better caching)
pip install --no-cache-dir --upgrade pip
# Install PyTorch with CUDA support for Python 3.12
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
# Install core dependencies
pip install --no-cache-dir \
numpy \
scipy \
librosa \
soundfile \
huggingface_hub \
einops \
transformers \
accelerate
# Install API dependencies
pip install --no-cache-dir \
fastapi \
uvicorn[standard] \
python-multipart \
pydantic
# Install moshi package with all dependencies (following Colab notebook)
pip install --no-cache-dir 'sphn<0.2'
pip install --no-cache-dir "moshi==0.2.8"
# Create directories for input/output
mkdir -p /app/input /app/output /app/scripts /app/api_output
# Download the Kyutai delayed-streams-modeling repository
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo
# Copy the TTS script from the repository
cp /app/kyutai-repo/scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"
# Create directories for input/output
mkdir -p /app/input /app/output /app/scripts /app/api_output
# Download the Kyutai delayed-streams-modeling repository
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo
# Copy the TTS script from the repository
cp scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"
# Create directories for input/output
mkdir -p /app/input /app/output /app/scripts /app/api_output
# Start TTS-Server
python /app/api_server.py
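The script ends by launching app/api_server.py in the foreground, so a minimal post-install check (an assumption: run from a second shell on the same host) is to hit the health endpoint:
```bash
# Verify the API server started at the end of install.sh is responding.
curl -s http://localhost:8000/health
```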

View File

@ -0,0 +1,100 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx==0.2.12",
# "numpy",
# "sentencepiece",
# "sounddevice",
# "sphn",
# ]
# ///
import argparse
import json
import mlx.core as mx
import mlx.nn as nn
import sentencepiece
import sphn
from huggingface_hub import hf_hub_download
from moshi_mlx import models, utils
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("in_file", help="The file to transcribe.")
parser.add_argument("--max-steps", default=4096)
parser.add_argument("--hf-repo")
parser.add_argument(
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
)
args = parser.parse_args()
audio, _ = sphn.read(args.in_file, sample_rate=24000)
if args.hf_repo is None:
if args.vad:
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
else:
args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
lm_config = hf_hub_download(args.hf_repo, "config.json")
with open(lm_config, "r") as fobj:
lm_config = json.load(fobj)
mimi_weights = hf_hub_download(args.hf_repo, lm_config["mimi_name"])
moshi_name = lm_config.get("moshi_name", "model.safetensors")
moshi_weights = hf_hub_download(args.hf_repo, moshi_name)
text_tokenizer = hf_hub_download(args.hf_repo, lm_config["tokenizer_name"])
lm_config = models.LmConfig.from_config_dict(lm_config)
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)
if moshi_weights.endswith(".q4.safetensors"):
nn.quantize(model, bits=4, group_size=32)
elif moshi_weights.endswith(".q8.safetensors"):
nn.quantize(model, bits=8, group_size=64)
print(f"loading model weights from {moshi_weights}")
if args.hf_repo.endswith("-candle"):
model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
else:
model.load_weights(moshi_weights, strict=True)
print(f"loading the text tokenizer from {text_tokenizer}")
text_tokenizer = sentencepiece.SentencePieceProcessor(text_tokenizer) # type: ignore
print(f"loading the audio tokenizer {mimi_weights}")
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(32))
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
print("warming up the model")
model.warmup()
gen = models.LmGen(
model=model,
max_steps=args.max_steps,
text_sampler=utils.Sampler(top_k=25, temp=0),
audio_sampler=utils.Sampler(top_k=250, temp=0.8),
check=False,
)
print(f"starting inference {audio.shape}")
audio = mx.concat([mx.array(audio), mx.zeros((1, 48000))], axis=-1)
last_print_was_vad = False
for start_idx in range(0, audio.shape[-1] // 1920 * 1920, 1920):
block = audio[:, None, start_idx : start_idx + 1920]
other_audio_tokens = audio_tokenizer.encode_step(block).transpose(0, 2, 1)
if args.vad:
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
if vad_heads:
pr_vad = vad_heads[2][0, 0, 0].item()
if pr_vad > 0.5 and not last_print_was_vad:
print(" [end of turn detected]")
last_print_was_vad = True
else:
text_token = gen.step(other_audio_tokens[0])
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
_text = None
if text_token not in (0, 3):
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
last_print_was_vad = False
print()
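The file name is not visible in this view; assuming it is saved as scripts/stt_from_file_mlx.py, a hedged usage sketch based on its argument parser and inline script metadata would be:
```bash
# Transcribe a local file with the MLX STT example; --vad enables the
# end-of-turn detection added in this script (the path is an assumption).
uv run scripts/stt_from_file_mlx.py audio/bria.mp3 --vad
```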

View File

@ -4,7 +4,7 @@
# "julius",
# "librosa",
# "soundfile",
- # "moshi",
+ # "moshi==0.2.11",
# ]
# ///
@ -20,8 +20,8 @@ import math
import julius
import moshi.models
import sphn
+ import time
import torch
- import tqdm
@dataclasses.dataclass
@ -128,6 +128,9 @@ def tokens_to_timestamped_text(
def main(args):
+ if args.vad and args.hf_repo is None:
+ args.hf_repo = "kyutai/stt-1b-en_fr-candle"
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
args.hf_repo,
moshi_weights=args.moshi_weight,
@ -150,7 +153,7 @@ def main(args):
audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
padding_token_id = info.raw_config.get("text_padding_token_id", 3)
- audio, input_sample_rate = sphn.read(args.file)
+ audio, input_sample_rate = sphn.read(args.in_file)
audio = torch.from_numpy(audio).to(args.device)
audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
if audio.shape[-1] % mimi.frame_size != 0:
@ -171,16 +174,35 @@
itertools.repeat(silence_chunk, n_suffix_chunks),
)
+ start_time = time.time()
+ nchunks = 0
+ last_print_was_vad = False
with mimi.streaming(1), lm_gen.streaming(1):
- for audio_chunk in tqdm.tqdm(chunks):
+ for audio_chunk in chunks:
+ nchunks += 1
audio_tokens = mimi.encode(audio_chunk)
+ if args.vad:
+ text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens)
+ if vad_heads:
+ pr_vad = vad_heads[2][0, 0, 0].cpu().item()
+ if pr_vad > 0.5 and not last_print_was_vad:
+ print(" [end of turn detected]")
+ last_print_was_vad = True
+ else:
text_tokens = lm_gen.step(audio_tokens)
- if text_tokens is not None:
+ text_token = text_tokens[0, 0, 0].cpu().item()
+ if text_token not in (0, 3):
+ _text = tokenizer.id_to_piece(text_tokens[0, 0, 0].cpu().item()) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
+ last_print_was_vad = False
text_tokens_accum.append(text_tokens)
- print(tokenizer.decode(text_tokens.numpy().tolist()))
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
+ dt = time.time() - start_time
+ print(
+ f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}"
+ )
timed_text = tokens_to_timestamped_text(
utterance_tokens,
tokenizer,
@ -211,6 +233,9 @@ if __name__ == "__main__":
parser.add_argument(
"--config-path", type=str, help="Path to a local config file.", default=None
)
+ parser.add_argument(
+ "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
+ )
parser.add_argument(
"--device",
type=str,
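Tying this diff back to the README: the audio file is now a positional `in_file` argument (the `--file` flag is gone), and the new `--vad` flag enables end-of-turn printing, defaulting to kyutai/stt-1b-en_fr-candle when no `--hf-repo` is given. A hedged example:
```bash
# Positional audio argument, as in the updated README.
uv run scripts/stt_from_file_pytorch.py --hf-repo kyutai/stt-2.6b-en audio/bria.mp3

# Same script with the newly added semantic VAD printing enabled.
uv run scripts/stt_from_file_pytorch.py --vad audio/bria.mp3
```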

View File

@ -1,7 +1,6 @@
"""An example script that illustrates how one can prompt Kyutai STT models."""
import argparse
- import dataclasses
import itertools
import math
from collections import deque

View File

@ -2,7 +2,7 @@
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
- # "moshi_mlx",
+ # "moshi_mlx==0.2.12",
# "numpy",
# "rustymimi",
# "sentencepiece",
@ -25,9 +25,17 @@ from moshi_mlx import models, utils
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--max-steps", default=4096)
- parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
+ parser.add_argument("--hf-repo")
+ parser.add_argument(
+ "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
+ )
args = parser.parse_args()
+ if args.hf_repo is None:
+ if args.vad:
+ args.hf_repo = "kyutai/stt-1b-en_fr-candle"
+ else:
+ args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
lm_config = hf_hub_download(args.hf_repo, "config.json")
with open(lm_config, "r") as fobj:
lm_config = json.load(fobj)
@ -45,6 +53,9 @@ if __name__ == "__main__":
nn.quantize(model, bits=8, group_size=64)
print(f"loading model weights from {moshi_weights}")
+ if args.hf_repo.endswith("-candle"):
+ model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
+ else:
model.load_weights(moshi_weights, strict=True)
print(f"loading the text tokenizer from {tokenizer}")
@ -71,6 +82,7 @@ if __name__ == "__main__":
block_queue.put(indata.copy())
print("recording audio from microphone, speak to get your words transcribed")
+ last_print_was_vad = False
with sd.InputStream(
channels=1,
dtype="float32",
@ -85,6 +97,14 @@ if __name__ == "__main__":
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
:, :, :other_codebooks
]
+ if args.vad:
+ text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
+ if vad_heads:
+ pr_vad = vad_heads[2][0, 0, 0].item()
+ if pr_vad > 0.5 and not last_print_was_vad:
+ print(" [end of turn detected]")
+ last_print_was_vad = True
+ else:
text_token = gen.step(other_audio_tokens[0])
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
@ -93,3 +113,4 @@ if __name__ == "__main__":
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
+ last_print_was_vad = False
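For the microphone example, a hedged invocation matching the new defaults (the name scripts/stt_from_mic_mlx.py is an assumption; with --vad and no --hf-repo the script now falls back to kyutai/stt-1b-en_fr-candle):
```bash
# Live transcription from the microphone with end-of-turn detection.
uv run scripts/stt_from_mic_mlx.py --vad
```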

View File

@ -2,7 +2,7 @@
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
- # "moshi_mlx==0.2.9",
+ # "moshi_mlx==0.2.12",
# "numpy",
# "sounddevice",
# ]
@ -14,19 +14,20 @@ import queue
import sys
import time
- import numpy as np
import mlx.core as mx
import mlx.nn as nn
+ import numpy as np
import sentencepiece
- import sphn
- import time
import sounddevice as sd
+ import sphn
- from moshi_mlx.client_utils import make_log
from moshi_mlx import models
+ from moshi_mlx.client_utils import make_log
+ from moshi_mlx.models.tts import (
+ DEFAULT_DSM_TTS_REPO,
+ DEFAULT_DSM_TTS_VOICE_REPO,
+ TTSModel,
+ )
from moshi_mlx.utils.loaders import hf_get
- from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
def log(level: str, msg: str):
@ -34,15 +35,32 @@ def log(level: str, msg: str):
def main():
- parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
+ parser = argparse.ArgumentParser(
+ description="Run Kyutai TTS using the MLX implementation"
+ )
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
- parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
- parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
- help="HF repo in which to look for the pretrained models.")
- parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
- help="HF repo in which to look for pre-computed voice embeddings.")
- parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
- parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
+ parser.add_argument(
+ "out", type=str, help="Output file to generate, use - for playing the audio"
+ )
+ parser.add_argument(
+ "--hf-repo",
+ type=str,
+ default=DEFAULT_DSM_TTS_REPO,
+ help="HF repo in which to look for the pretrained models.",
+ )
+ parser.add_argument(
+ "--voice-repo",
+ default=DEFAULT_DSM_TTS_VOICE_REPO,
+ help="HF repo in which to look for pre-computed voice embeddings.",
+ )
+ parser.add_argument(
+ "--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
+ )
+ parser.add_argument(
+ "--quantize",
+ type=int,
+ help="The quantization to be applied, e.g. 8 for 8 bits.",
+ )
args = parser.parse_args()
mx.random.seed(299792458)
@ -58,6 +76,9 @@ def main():
moshi_weights = hf_get(moshi_name, args.hf_repo)
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
lm_config = models.LmConfig.from_config_dict(raw_config)
+ # There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
+ # The following line gets around it for now.
+ lm_config.transformer.max_seq_len = lm_config.transformer.context
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)
@ -96,7 +117,7 @@ def main():
if tts_model.valid_cfg_conditionings:
# Model was trained with CFG distillation.
cfg_coef_conditioning = tts_model.cfg_coef
- tts_model.cfg_coef = 1.
+ tts_model.cfg_coef = 1.0
cfg_is_no_text = False
cfg_is_no_prefix = False
else:
@ -118,9 +139,12 @@ def main():
voices = [tts_model.get_voice_path(args.voice)]
else:
voices = []
- all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
+ all_attributes = [
+ tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
+ ]
wav_frames = queue.Queue()
def _on_frame(frame):
if (frame == -1).any():
return
@ -146,16 +170,20 @@ def main():
return result
if args.out == "-":
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = wav_frames.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
- with sd.OutputStream(samplerate=mimi.sample_rate,
+ with sd.OutputStream(
+ samplerate=mimi.sample_rate,
blocksize=1920,
channels=1,
- callback=audio_callback):
+ callback=audio_callback,
+ ):
run()
time.sleep(3)
while True:
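A hedged usage sketch for the MLX TTS script after these changes (the name scripts/tts_mlx.py is inferred from the commit history above): text comes from stdin, `-` as the output plays audio directly, and `--quantize` applies on-the-fly quantization:
```bash
# Generate speech with the MLX implementation and play it through the speakers.
echo "Hey, how are you?" | uv run scripts/tts_mlx.py - - --quantize 8
```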

View File

@ -0,0 +1,317 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx==0.2.12",
# "numpy",
# "sounddevice",
# ]
# ///
import argparse
from dataclasses import dataclass
import json
import queue
import sys
import time
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import sentencepiece
import sounddevice as sd
import sphn
import typing as tp
from moshi_mlx import models
from moshi_mlx.models.generate import LmGen
from moshi_mlx.client_utils import make_log
from moshi_mlx.modules.conditioner import (
ConditionAttributes,
ConditionTensor,
dropout_all_conditions,
)
from moshi_mlx.utils.sampling import Sampler
from moshi_mlx.models.tts import (
Entry,
DEFAULT_DSM_TTS_REPO,
DEFAULT_DSM_TTS_VOICE_REPO,
TTSModel,
script_to_entries,
)
from moshi_mlx.utils.loaders import hf_get
def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]:
multi_speaker = first_turn and model.multi_speaker
return script_to_entries(
model.tokenizer,
model.machine.token_ids,
model.mimi.frame_rate,
[script],
multi_speaker=multi_speaker,
padding_between=1,
)
def _make_null(
all_attributes: tp.Sequence[ConditionAttributes],
) -> list[ConditionAttributes]:
# When using CFG, returns the null conditions.
return dropout_all_conditions(all_attributes)
@dataclass
class TTSGen:
tts_model: TTSModel
attributes: tp.Sequence[ConditionAttributes]
on_frame: tp.Optional[tp.Callable[[mx.array], None]] = None
def __post_init__(self):
tts_model = self.tts_model
attributes = self.attributes
self.offset = 0
self.state = self.tts_model.machine.new_state([])
if tts_model.cfg_coef != 1.0:
if tts_model.valid_cfg_conditionings:
raise ValueError(
"This model does not support direct CFG, but was trained with "
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
)
nulled = _make_null(attributes)
attributes = list(attributes) + nulled
assert tts_model.lm.condition_provider is not None
self.ct = None
self.cross_attention_src = None
for _attr in attributes:
for _key, _value in _attr.text.items():
_ct = tts_model.lm.condition_provider.condition_tensor(_key, _value)
if self.ct is None:
self.ct = _ct
else:
self.ct = ConditionTensor(self.ct.tensor + _ct.tensor)
for _key, _value in _attr.tensor.items():
_conditioner = tts_model.lm.condition_provider.conditioners[_key]
_ca_src = _conditioner.condition(_value)
if self.cross_attention_src is None:
self.cross_attention_src = _ca_src
else:
raise ValueError("multiple cross-attention conditioners")
def _on_audio_hook(audio_tokens):
delays = tts_model.lm.delays
for q in range(audio_tokens.shape[0]):
delay = delays[q]
if self.offset < delay + tts_model.delay_steps:
audio_tokens[q] = tts_model.machine.token_ids.zero
def _on_text_hook(text_tokens):
tokens = text_tokens.tolist()
out_tokens = []
for token in tokens:
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
out_tokens.append(out_token)
text_tokens[:] = mx.array(out_tokens, dtype=mx.int64)
self.lm_gen = LmGen(
tts_model.lm,
max_steps=tts_model.max_gen_length,
text_sampler=Sampler(temp=tts_model.temp),
audio_sampler=Sampler(temp=tts_model.temp),
cfg_coef=tts_model.cfg_coef,
on_text_hook=_on_text_hook,
on_audio_hook=_on_audio_hook,
# TODO(laurent):
# cfg_is_masked_until=cfg_is_masked_until,
# cfg_is_no_text=cfg_is_no_text,
)
def process_last(self):
while len(self.state.entries) > 0 or self.state.end_step is not None:
self._step()
additional_steps = (
self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
)
for _ in range(additional_steps):
self._step()
def process(self):
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
self._step()
def _step(self):
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
input_tokens = (
mx.ones((1, missing), dtype=mx.int64)
* self.tts_model.machine.token_ids.zero
)
self.lm_gen.step(
input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src
)
frame = self.lm_gen.last_audio_tokens()
self.offset += 1
if frame is not None:
if self.on_frame is not None:
self.on_frame(frame)
def append_entry(self, entry):
self.state.entries.append(entry)
def log(level: str, msg: str):
print(make_log(level, msg))
def main():
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the MLX implementation"
)
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--hf-repo",
type=str,
default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
)
parser.add_argument(
"--quantize",
type=int,
help="The quantization to be applied, e.g. 8 for 8 bits.",
)
args = parser.parse_args()
mx.random.seed(299792458)
log("info", "retrieving checkpoints")
raw_config = hf_get("config.json", args.hf_repo)
with open(hf_get(raw_config), "r") as fobj:
raw_config = json.load(fobj)
mimi_weights = hf_get(raw_config["mimi_name"], args.hf_repo)
moshi_name = raw_config.get("moshi_name", "model.safetensors")
moshi_weights = hf_get(moshi_name, args.hf_repo)
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
lm_config = models.LmConfig.from_config_dict(raw_config)
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
# The following line gets around it for now.
lm_config.transformer.max_seq_len = lm_config.transformer.context
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)
log("info", f"loading model weights from {moshi_weights}")
model.load_pytorch_weights(str(moshi_weights), lm_config, strict=True)
if args.quantize is not None:
log("info", f"quantizing model to {args.quantize} bits")
nn.quantize(model.depformer, bits=args.quantize)
for layer in model.transformer.layers:
nn.quantize(layer.self_attn, bits=args.quantize)
nn.quantize(layer.gating, bits=args.quantize)
log("info", f"loading the text tokenizer from {tokenizer}")
text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer)) # type: ignore
log("info", f"loading the audio tokenizer {mimi_weights}")
generated_codebooks = lm_config.generated_codebooks
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(generated_codebooks))
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
cfg_coef_conditioning = None
tts_model = TTSModel(
model,
audio_tokenizer,
text_tokenizer,
voice_repo=args.voice_repo,
temp=0.6,
cfg_coef=1,
max_padding=8,
initial_padding=2,
final_padding=2,
padding_bonus=0,
raw_config=raw_config,
)
if tts_model.valid_cfg_conditionings:
# Model was trained with CFG distillation.
cfg_coef_conditioning = tts_model.cfg_coef
tts_model.cfg_coef = 1.0
mimi = tts_model.mimi
log("info", "reading input from stdin")
if tts_model.multi_speaker:
voices = [tts_model.get_voice_path(args.voice)]
else:
voices = []
all_attributes = [
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
]
wav_frames = queue.Queue()
def _on_frame(frame):
if (frame == -1).any():
return
_pcm = tts_model.mimi.decode_step(frame[:, :, None])
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
wav_frames.put_nowait(_pcm)
gen = TTSGen(tts_model, all_attributes, on_frame=_on_frame)
def run():
log("info", "starting the inference loop")
first_turn = True
for line in sys.stdin:
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
first_turn = False
for entry in entries:
gen.append_entry(entry)
gen.process()
gen.process_last()
if args.out == "-":
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = wav_frames.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
with sd.OutputStream(
samplerate=mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback,
):
run()
while True:
if wav_frames.qsize() == 0:
break
time.sleep(1)
else:
run()
frames = []
while True:
try:
frames.append(wav_frames.get_nowait())
except queue.Empty:
break
wav = np.concat(frames, -1)
sphn.write_wav(args.out, wav, mimi.sample_rate)
if __name__ == "__main__":
main()
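The streaming MLX example above reads text line by line from stdin and synthesizes as it goes; assuming it is saved as scripts/tts_mlx_streaming.py (the name is not shown in this view), usage would look like:
```bash
# Stream text from stdin straight to the speakers ("-" as the output).
echo "Hey, how are you?" | uv run scripts/tts_mlx_streaming.py -
```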

View File

@@ -1,7 +1,7 @@
 # /// script
 # requires-python = ">=3.12"
 # dependencies = [
-# "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi",
+# "moshi==0.2.11",
 # "torch",
 # "sphn",
 # "sounddevice",
@@ -11,20 +11,14 @@ import argparse
 import sys
 import numpy as np
+import queue
 import sphn
+import time
 import torch
 from moshi.models.loaders import CheckpointInfo
 from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
-def play_audio(audio: np.ndarray, sample_rate: int):
-    # Requires the Portaudio library which might not be available in all environments.
-    import sounddevice as sd
-    with sd.OutputStream(samplerate=sample_rate, blocksize=1920, channels=1):
-        sd.play(audio, sample_rate)
 def main():
     parser = argparse.ArgumentParser(
         description="Run Kyutai TTS using the PyTorch implementation"
@@ -50,12 +44,18 @@ def main():
         help="The voice to use, relative to the voice repo root. "
         f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="Device on which to run, defaults to 'cuda'.",
+    )
     args = parser.parse_args()
     print("Loading model...")
     checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
     tts_model = TTSModel.from_checkpoint_info(
-        checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
+        checkpoint_info, n_q=32, temp=0.6, device=args.device
     )
     if args.inp == "-":
@@ -66,31 +66,74 @@ def main():
         with open(args.inp, "r") as fobj:
             text = fobj.read().strip()
-    # You could also generate multiple audios at once by passing a list of texts.
+    # If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]
     entries = tts_model.prepare_script([text], padding_between=1)
-    voice_path = tts_model.get_voice_path(args.voice)
+    if args.voice.endswith(".safetensors"):
+        voice_path = args.voice
+    else:
+        voice_path = tts_model.get_voice_path(args.voice)
     # CFG coef goes here because the model was trained with CFG distillation,
     # so it's not _actually_ doing CFG at inference time.
+    # Also, if you are generating a dialog, you should have two voices in the list.
     condition_attributes = tts_model.make_condition_attributes(
         [voice_path], cfg_coef=2.0
     )
-    print("Generating audio...")
-    # This doesn't do streaming generation, but the model allows it. For now, see Rust
-    # example.
-    result = tts_model.generate([entries], [condition_attributes])
-    frames = torch.cat(result.frames, dim=-1)
-    audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
-    with torch.no_grad():
-        audios = tts_model.mimi.decode(audio_tokens)
+    _frames_cnt = 0
     if args.out == "-":
-        print("Playing audio...")
-        play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate)
+        # Stream the audio to the speakers using sounddevice.
+        import sounddevice as sd
+        pcms = queue.Queue()
+        def _on_frame(frame):
+            nonlocal _frames_cnt
+            if (frame != -1).all():
+                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
+                pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
+                _frames_cnt += 1
+                print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
+        def audio_callback(outdata, _a, _b, _c):
+            try:
+                pcm_data = pcms.get(block=False)
+                outdata[:, 0] = pcm_data
+            except queue.Empty:
+                outdata[:] = 0
+        with sd.OutputStream(
+            samplerate=tts_model.mimi.sample_rate,
+            blocksize=1920,
+            channels=1,
+            callback=audio_callback,
+        ):
+            with tts_model.mimi.streaming(1):
+                tts_model.generate(
+                    [entries], [condition_attributes], on_frame=_on_frame
+                )
+            time.sleep(3)
+            while True:
+                if pcms.qsize() == 0:
+                    break
+                time.sleep(1)
     else:
-        sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
-        print(f"Audio saved to {args.out}")
+        def _on_frame(frame):
+            nonlocal _frames_cnt
+            if (frame != -1).all():
+                _frames_cnt += 1
+                print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
+        result = tts_model.generate(
+            [entries], [condition_attributes], on_frame=_on_frame
+        )
+        with tts_model.mimi.streaming(1), torch.no_grad():
+            pcms = []
+            for frame in result.frames[tts_model.delay_steps :]:
+                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
+                pcms.append(np.clip(pcm[0, 0], -1, 1))
+        pcm = np.concatenate(pcms, axis=-1)
+        sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
 if __name__ == "__main__":
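The two new comments in the hunk above point at multi-speaker use: prepare_script accepts a list of alternating turns, and make_condition_attributes then expects one voice per speaker. A hedged sketch of what that call site might look like, assuming tts_model is loaded as in the script; the turn texts and the second voice name are placeholders, not files guaranteed to exist in the voice repo.
# Sketch only: assumes tts_model is a loaded TTSModel, as in the script above.
turns = [
    "Hi, how are you doing?",           # speaker 1
    "Pretty well, thanks for asking.",  # speaker 2
    "Glad to hear it!",                 # speaker 1 again
]
entries = tts_model.prepare_script(turns, padding_between=1)
# One voice per speaker; both names are placeholders.
voice_paths = [
    tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav"),
    tts_model.get_voice_path("some/other_voice.wav"),
]
condition_attributes = tts_model.make_condition_attributes(voice_paths, cfg_coef=2.0)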


@@ -0,0 +1,261 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "moshi==0.2.11",
# "torch",
# "sphn",
# "sounddevice",
# ]
# ///
import argparse
from dataclasses import dataclass
import sys
import numpy as np
import queue
import sphn
import time
import torch
import typing as tp
from moshi.models.loaders import CheckpointInfo
from moshi.conditioners import dropout_all_conditions
from moshi.models.lm import LMGen
from moshi.models.tts import (
Entry,
DEFAULT_DSM_TTS_REPO,
DEFAULT_DSM_TTS_VOICE_REPO,
TTSModel,
ConditionAttributes,
script_to_entries,
)
def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]:
multi_speaker = first_turn and model.multi_speaker
return script_to_entries(
model.tokenizer,
model.machine.token_ids,
model.mimi.frame_rate,
[script],
multi_speaker=multi_speaker,
padding_between=1,
)
def _make_null(
all_attributes: tp.Sequence[ConditionAttributes],
) -> list[ConditionAttributes]:
# When using CFG, returns the null conditions.
return dropout_all_conditions(all_attributes)
@dataclass
class TTSGen:
tts_model: TTSModel
attributes: tp.Sequence[ConditionAttributes]
on_frame: tp.Optional[tp.Callable[[torch.Tensor], None]] = None
def __post_init__(self):
tts_model = self.tts_model
attributes = self.attributes
self.offset = 0
self.state = self.tts_model.machine.new_state([])
if tts_model.cfg_coef != 1.0:
if tts_model.valid_cfg_conditionings:
raise ValueError(
"This model does not support direct CFG, but was trained with "
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
)
nulled = _make_null(attributes)
attributes = list(attributes) + nulled
assert tts_model.lm.condition_provider is not None
prepared = tts_model.lm.condition_provider.prepare(attributes)
condition_tensors = tts_model.lm.condition_provider(prepared)
def _on_text_logits_hook(text_logits):
if tts_model.padding_bonus:
text_logits[..., tts_model.machine.token_ids.pad] += (
tts_model.padding_bonus
)
return text_logits
def _on_audio_hook(audio_tokens):
audio_offset = tts_model.lm.audio_offset
delays = tts_model.lm.delays
for q in range(audio_tokens.shape[1]):
delay = delays[q + audio_offset]
if self.offset < delay + tts_model.delay_steps:
audio_tokens[:, q] = tts_model.machine.token_ids.zero
def _on_text_hook(text_tokens):
tokens = text_tokens.tolist()
out_tokens = []
for token in tokens:
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
out_tokens.append(out_token)
text_tokens[:] = torch.tensor(
out_tokens, dtype=torch.long, device=text_tokens.device
)
tts_model.lm.dep_q = tts_model.n_q
self.lm_gen = LMGen(
tts_model.lm,
temp=tts_model.temp,
temp_text=tts_model.temp,
cfg_coef=tts_model.cfg_coef,
condition_tensors=condition_tensors,
on_text_logits_hook=_on_text_logits_hook,
on_text_hook=_on_text_hook,
on_audio_hook=_on_audio_hook,
cfg_is_masked_until=None,
cfg_is_no_text=True,
)
self.lm_gen.streaming_forever(1)
def process_last(self):
while len(self.state.entries) > 0 or self.state.end_step is not None:
self._step()
additional_steps = (
self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
)
for _ in range(additional_steps):
self._step()
def process(self):
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
self._step()
def _step(self):
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
input_tokens = torch.full(
(1, missing, 1),
self.tts_model.machine.token_ids.zero,
dtype=torch.long,
device=self.tts_model.lm.device,
)
frame = self.lm_gen.step(input_tokens)
self.offset += 1
if frame is not None:
if self.on_frame is not None:
self.on_frame(frame)
def append_entry(self, entry):
self.state.entries.append(entry)
@torch.no_grad()
def main():
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation"
)
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--hf-repo",
type=str,
default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice",
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
help="The voice to use, relative to the voice repo root. "
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device on which to run, defaults to 'cuda'.",
)
args = parser.parse_args()
print("Loading model...")
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
tts_model = TTSModel.from_checkpoint_info(
checkpoint_info, n_q=32, temp=0.6, device=args.device
)
if args.voice.endswith(".safetensors"):
voice_path = args.voice
else:
voice_path = tts_model.get_voice_path(args.voice)
# CFG coef goes here because the model was trained with CFG distillation,
# so it's not _actually_ doing CFG at inference time.
# Also, if you are generating a dialog, you should have two voices in the list.
condition_attributes = tts_model.make_condition_attributes(
[voice_path], cfg_coef=2.0
)
if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
if args.out == "-":
# Stream the audio to the speakers using sounddevice.
import sounddevice as sd
pcms = queue.Queue()
def _on_frame(frame):
if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = pcms.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame)
with sd.OutputStream(
samplerate=tts_model.mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback,
), tts_model.mimi.streaming(1):
first_turn = True
for line in sys.stdin:
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
first_turn = False
for entry in entries:
gen.append_entry(entry)
gen.process()
gen.process_last()
while True:
if pcms.qsize() == 0:
break
time.sleep(1)
else:
pcms = []
def _on_frame(frame: torch.Tensor):
if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.append(np.clip(pcm[0, 0], -1, 1))
gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame)
with tts_model.mimi.streaming(1):
first_turn = True
for line in sys.stdin:
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
first_turn = False
for entry in entries:
gen.append_entry(entry)
gen.process()
gen.process_last()
pcm = np.concatenate(pcms, axis=-1)
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
if __name__ == "__main__":
main()
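Stripped of the stdin loop and the two output branches, driving TTSGen comes down to appending entries, calling process() while input keeps arriving, and finishing with process_last(). A minimal sketch of that flow as a helper, assuming tts_model and condition_attributes are built exactly as in main() above and reusing the TTSGen and prepare_script defined in this file; the helper itself is illustrative, not part of the script.
import numpy as np
import torch
@torch.no_grad()
def synthesize_line(tts_model, condition_attributes, text: str) -> np.ndarray:
    """Collect the PCM for a single line of text in memory (sketch)."""
    pcms: list[np.ndarray] = []
    def _collect(frame):
        # Frames still containing -1 are padding produced before the delays are filled.
        if (frame != -1).all():
            pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
            pcms.append(np.clip(pcm[0, 0], -1, 1))
    gen = TTSGen(tts_model, [condition_attributes], on_frame=_collect)
    with tts_model.mimi.streaming(1):
        for entry in prepare_script(tts_model, text, first_turn=True):
            gen.append_entry(entry)
            gen.process()       # step while enough entries are queued
        gen.process_last()      # flush the remaining entries plus the delay steps
    return np.concatenate(pcms, axis=-1)  # 24 kHz mono PCM, ready for sphn.write_wav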

178
scripts/tts_rust_server.py Normal file

@@ -0,0 +1,178 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "msgpack",
# "numpy",
# "sphn",
# "websockets",
# "sounddevice",
# "tqdm",
# ]
# ///
import argparse
import asyncio
import sys
from urllib.parse import urlencode
import msgpack
import numpy as np
import sounddevice as sd
import sphn
import tqdm
import websockets
SAMPLE_RATE = 24000
TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."
DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
AUTH_TOKEN = "public_token"
async def receive_messages(websocket: websockets.ClientConnection, output_queue):
with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar:
accumulated_samples = 0
last_seconds = 0
async for message_bytes in websocket:
msg = msgpack.unpackb(message_bytes)
if msg["type"] == "Audio":
pcm = np.array(msg["pcm"]).astype(np.float32)
await output_queue.put(pcm)
accumulated_samples += len(msg["pcm"])
current_seconds = accumulated_samples // SAMPLE_RATE
if current_seconds > last_seconds:
pbar.update(current_seconds - last_seconds)
last_seconds = current_seconds
print("End of audio.")
await output_queue.put(None) # Signal end of audio
async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
if out == "-":
should_exit = False
def audio_callback(outdata, _a, _b, _c):
nonlocal should_exit
try:
pcm_data = output_queue.get_nowait()
if pcm_data is not None:
outdata[:, 0] = pcm_data
else:
should_exit = True
outdata[:] = 0
except asyncio.QueueEmpty:
outdata[:] = 0
with sd.OutputStream(
samplerate=SAMPLE_RATE,
blocksize=1920,
channels=1,
callback=audio_callback,
):
while True:
if should_exit:
break
await asyncio.sleep(1)
else:
frames = []
while True:
item = await output_queue.get()
if item is None:
break
frames.append(item)
sphn.write_wav(out, np.concat(frames, -1), SAMPLE_RATE)
print(f"Saved audio to {out}")
async def read_lines_from_stdin():
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
loop = asyncio.get_running_loop()
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
while True:
line = await reader.readline()
if not line:
break
yield line.decode().rstrip()
async def read_lines_from_file(path: str):
queue = asyncio.Queue()
loop = asyncio.get_running_loop()
def producer():
with open(path, "r", encoding="utf-8") as f:
for line in f:
asyncio.run_coroutine_threadsafe(queue.put(line), loop)
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
await asyncio.to_thread(producer)
while True:
line = await queue.get()
if line is None:
break
yield line
async def get_lines(source: str):
if source == "-":
async for line in read_lines_from_stdin():
yield line
else:
async for line in read_lines_from_file(source):
yield line
async def websocket_client():
parser = argparse.ArgumentParser(description="Use the TTS streaming API")
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--voice",
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
help="The voice to use, relative to the voice repo root. "
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
)
parser.add_argument(
"--url",
help="The URL of the server to which to send the audio",
default="ws://127.0.0.1:8080",
)
parser.add_argument("--api-key", default="public_token")
args = parser.parse_args()
params = {"voice": args.voice, "format": "PcmMessagePack"}
uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
print(uri)
if args.inp == "-":
if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
headers = {"kyutai-api-key": args.api_key}
async with websockets.connect(uri, additional_headers=headers) as websocket:
print("connected")
async def send_loop():
print("go send")
async for line in get_lines(args.inp):
for word in line.split():
await websocket.send(msgpack.packb({"type": "Text", "text": word}))
await websocket.send(msgpack.packb({"type": "Eos"}))
output_queue = asyncio.Queue()
receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
send_task = asyncio.create_task(send_loop())
await asyncio.gather(receive_task, output_audio_task, send_task)
if __name__ == "__main__":
asyncio.run(websocket_client())
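For reference, the wire format used above is a small msgpack protocol: the client sends one {"type": "Text", "text": ...} message per word followed by a single {"type": "Eos"}, and the server answers with {"type": "Audio", "pcm": [...]} messages carrying raw 24 kHz PCM. A short sketch of just the encode/decode step, using only the message shapes visible in this script; the surrounding websocket plumbing is omitted.
import msgpack
import numpy as np
# Client -> server: one word per Text message, then a final Eos.
text_msg = msgpack.packb({"type": "Text", "text": "Hello"})
eos_msg = msgpack.packb({"type": "Eos"})
# Server -> client: Audio messages whose "pcm" field holds float samples at 24 kHz.
def decode_message(raw: bytes) -> np.ndarray | None:
    msg = msgpack.unpackb(raw)
    if msg["type"] == "Audio":
        return np.array(msg["pcm"], dtype=np.float32)
    return None  # other message types are ignored in this sketch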

File diff suppressed because one or more lines are too long