Compare commits
52 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 96042c9983 | |||
| 7e92c6f28a | |||
| 8e51c6eab9 | |||
| 34e58b8055 | |||
| f23d1be027 | |||
|
|
cf97f8d863 | ||
|
|
09468c239a | ||
|
|
07729ed47e | ||
|
|
af2283de3f | ||
|
|
7dc926d50c | ||
|
|
ab8e8c59b7 | ||
|
|
5f17114618 | ||
|
|
405a82ba3f | ||
|
|
3b584b100c | ||
|
|
a98eb94ade | ||
|
|
a2f031deb5 | ||
|
|
66a33c989f | ||
|
|
89a2ced839 | ||
|
|
baf0c75bba | ||
|
|
952319de90 | ||
|
|
6d3bb6b1f1 | ||
|
|
12dbe36b0b | ||
|
|
cafac63222 | ||
|
|
7336d7a3da | ||
|
|
70500c620e | ||
|
|
f8e97aa4f3 | ||
|
|
91a4d120cb | ||
|
|
bfc200f6ee | ||
|
|
f9739881e6 | ||
|
|
99599fa408 | ||
|
|
3a4165a84f | ||
|
|
e9bac066ea | ||
|
|
eae5e17975 | ||
|
|
c1d248abba | ||
|
|
c6f262346f | ||
|
|
3573ee90af | ||
|
|
25574aa104 | ||
|
|
1cd9529f65 | ||
|
|
0ee2354176 | ||
|
|
dc8bffabe0 | ||
|
|
5f8e924176 | ||
|
|
d3bed09f9a | ||
|
|
ef52b8ef0f | ||
|
|
d92e4c2695 | ||
|
|
6c1e9f12cf | ||
|
|
236df522b8 | ||
|
|
20cf8d7365 | ||
|
|
ae575a04c6 | ||
|
|
433dca3751 | ||
|
|
07ac744609 | ||
|
|
96ff217437 | ||
|
|
7294fbcc3a |
83
.github/ISSUE_TEMPLATE/bug.yml
vendored
Normal file
83
.github/ISSUE_TEMPLATE/bug.yml
vendored
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
name: Bug Report
|
||||
description: You found a bug.
|
||||
labels: ["bug", "triage"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Please first check the [FAQ](https://github.com/kyutai-labs/delayed-streams-modeling/blob/main/FAQ.md).
|
||||
- type: dropdown
|
||||
id: backend
|
||||
attributes:
|
||||
label: Backend impacted
|
||||
description: Which backend is concerned with your bug report?
|
||||
options:
|
||||
- The PyTorch implementation
|
||||
- The MLX implementation
|
||||
- The Rust implementation
|
||||
- Other / All
|
||||
default: 0
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
id: os
|
||||
attributes:
|
||||
label: Operating system
|
||||
description: What is your operating system?
|
||||
options:
|
||||
- Linux
|
||||
- Mac OS X
|
||||
- Windows (unsupported)
|
||||
default: 0
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
id: hardware
|
||||
attributes:
|
||||
label: Hardware
|
||||
description: What hardware are you using?
|
||||
options:
|
||||
- CPU
|
||||
- GPU with CUDA
|
||||
- Metal with MLX
|
||||
default: 0
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Description
|
||||
description: Provide a detailed description of your bug.
|
||||
placeholder:
|
||||
value:
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: more_info
|
||||
attributes:
|
||||
label: Extra information
|
||||
description: Please provide any other relevant information, such as log extracts, code etc.
|
||||
placeholder:
|
||||
value:
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: env
|
||||
attributes:
|
||||
label: Environment
|
||||
description: Please provide any other relevant information, such as log extracts, code etc.
|
||||
placeholder:
|
||||
value: |
|
||||
Fill in the following information on your system.
|
||||
- Operating system version:
|
||||
|
||||
If the backend impacted is PyTorch:
|
||||
- Python version:
|
||||
- PyTorch version:
|
||||
- CUDA version (run `python -c 'import torch; print(torch.version.cuda)'`):
|
||||
- GPU model and memory:
|
||||
|
||||
If the backend is MLX:
|
||||
- Mac model:
|
||||
validations:
|
||||
required: true
|
||||
40
.github/ISSUE_TEMPLATE/question.yml
vendored
Normal file
40
.github/ISSUE_TEMPLATE/question.yml
vendored
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
name: Question
|
||||
description: You have a question about the codebase, the paper, or the implementation.
|
||||
labels: ["question", "triage"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Please first check the [FAQ](https://github.com/kyutai-labs/delayed-streams-modeling/blob/main/FAQ.md).
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Due diligence
|
||||
description: Have you searched the existing issues / FAQ / Google / asked ChatGPT?
|
||||
options:
|
||||
- label: I have done my due diligence in trying to find the answer myself.
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: backend
|
||||
attributes:
|
||||
label: Topic
|
||||
description: What is your question about?
|
||||
options:
|
||||
- The paper
|
||||
- The PyTorch implementation
|
||||
- The MLX implementation
|
||||
- The Rust implementation
|
||||
- Other / All
|
||||
default: 0
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: question
|
||||
attributes:
|
||||
label: Question
|
||||
description: What is your question?
|
||||
placeholder: Your question. Please make sure this is directly related to our codebase. We will not provide support for installing PyTorch, CUDA, Rust etc.
|
||||
value:
|
||||
validations:
|
||||
required: true
|
||||
2
.github/actions/moshi_build/action.yml
vendored
2
.github/actions/moshi_build/action.yml
vendored
|
|
@ -19,7 +19,7 @@ runs:
|
|||
. env/bin/activate
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install moshi==0.2.6
|
||||
pip install moshi==0.2.7
|
||||
pip install pre-commit
|
||||
- name: Setup env
|
||||
shell: bash
|
||||
|
|
|
|||
4
.github/workflows/precommit.yml
vendored
4
.github/workflows/precommit.yml
vendored
|
|
@ -13,5 +13,5 @@ jobs:
|
|||
- uses: actions/checkout@v2
|
||||
- uses: ./.github/actions/moshi_build
|
||||
- run: |
|
||||
. env/bin/activate
|
||||
bash .git/hooks/pre-commit
|
||||
source env/bin/activate
|
||||
pre-commit run --all-files
|
||||
|
|
|
|||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -192,3 +192,4 @@ cython_debug/
|
|||
# refer to https://docs.cursor.com/context/ignore-files
|
||||
.cursorignore
|
||||
.cursorindexingignore
|
||||
out*.wav
|
||||
|
|
|
|||
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"python.analysis.typeCheckingMode": "standard"
|
||||
}
|
||||
56
FAQ.md
Normal file
56
FAQ.md
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
# FAQ
|
||||
|
||||
Here is the answer to a number of frequently asked questions.
|
||||
|
||||
### Torch compilation issues
|
||||
|
||||
With some PyTorch/triton versions, one might encounter compilation errors
|
||||
like the following:
|
||||
```
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
File "site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1153, in make_launcher
|
||||
"launch_enter_hook": binary.__class__.launch_enter_hook,
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
torch._inductor.exc.InductorError: AttributeError: type object 'CompiledKernel' has no attribute 'launch_enter_hook'
|
||||
```
|
||||
|
||||
If that's the case, you can disable torch compilation by setting the following
|
||||
environment variable.
|
||||
```bash
|
||||
export NO_TORCH_COMPILE=1
|
||||
```
|
||||
|
||||
### Issues installing the sentencepiece dependency
|
||||
|
||||
On some linux distributions (arch) or on macos, the local version of cmake can
|
||||
be too recent for the sentencepiece dependency.
|
||||
|
||||
```
|
||||
CMake Error at CMakeLists.txt:15 (cmake_minimum_required):
|
||||
Compatibility with CMake < 3.5 has been removed from CMake.
|
||||
```
|
||||
|
||||
You can either downgrade your cmake version, e.g. 3.31.0 on arch works or try
|
||||
setting `CMAKE_POLICY_VERSION_MINIMUM=3.5`.
|
||||
|
||||
If you run into some errors when compiling the sentencepiece rust bindings,
|
||||
these could also be due to gcc being too recent, e.g. gcc 15. You can get
|
||||
around this by using gcc-13, e.g. by setting the following after installing
|
||||
the proper gcc packages.
|
||||
```bash
|
||||
export CMAKE_C_COMPILER=/usr/bin/gcc-13
|
||||
export CMAKE_CXX_COMPILER=/usr/bin/g++-13
|
||||
CC=gcc-13 CXX=g++-13 cargo build --release
|
||||
```
|
||||
|
||||
Alternatively you can set `CXXFLAGS="-include cstdint"`, see this
|
||||
[issue](https://github.com/google/sentencepiece/issues/1108).
|
||||
|
||||
### Will you release training code?
|
||||
|
||||
Some finetuning code can be found in the [kyutai-labs/moshi-finetune repo](https://github.com/kyutai-labs/moshi-finetune).
|
||||
This code has not been adapted to the Speech-To-Text and Text-To-Speech models
|
||||
yet, but it should be a good starting point.
|
||||
|
||||
|
||||
171
README.md
171
README.md
|
|
@ -1,21 +1,24 @@
|
|||
# Delayed Streams Modeling: Kyutai STT & TTS
|
||||
|
||||
This repo contains instructions and examples of how to run
|
||||
[Kyutai Speech-To-Text](#kyutai-speech-to-text)
|
||||
and [Kyutai Text-To-Speech](#kyutai-text-to-speech) models.
|
||||
These models are powered by delayed streams modeling (DSM),
|
||||
a flexible formulation for streaming, multimodal sequence-to-sequence learning.
|
||||
See also [Unmute](https://github.com/kyutai-labs/unmute), an voice AI system built using Kyutai STT and Kyutai TTS.
|
||||
|
||||
But wait, what is "Delayed Streams Modeling"? It is a technique for solving many streaming X-to-Y tasks (with X, Y in `{speech, text}`)
|
||||
that formalize the approach we had with Moshi and Hibiki. A pre-print paper is coming soon!
|
||||
|
||||
## Kyutai Speech-To-Text
|
||||
|
||||
<a href="https://huggingface.co/collections/kyutai/speech-to-text-685403682cf8a23ab9466886" target="_blank" style="margin: 2px;">
|
||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiSTT-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||
</a>
|
||||
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/transcribe_via_pytorch.ipynb">
|
||||
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb">
|
||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||
</a>
|
||||
|
||||
|
||||
This repo contains instructions and examples of how to run Kyutai Speech-To-Text models.
|
||||
These models are powered by delayed streams modeling (DSM),
|
||||
a flexible formulation for streaming, multimodal sequence-to-sequence learning.
|
||||
|
||||
Text-to-speech models based on DSM coming soon!
|
||||
[Sign up here](https://docs.google.com/forms/d/15sB4zyfuwyXTii4OM74hFGkk4DlDNynJ9xywnaEzE4I/edit)
|
||||
to be notified when we open-source text-to-speech and [Unmute](https://unmute.sh).
|
||||
|
||||
## Kyutai Speech-To-Text
|
||||
|
||||
**More details can be found on the [project page](https://kyutai.org/next/stt).**
|
||||
|
||||
Kyutai STT models are optimized for real-time usage, can be batched for efficiency, and return word level timestamps.
|
||||
|
|
@ -48,16 +51,17 @@ Here is how to choose which one to use:
|
|||
MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
|
||||
If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
|
||||
|
||||
### PyTorch implementation
|
||||
<details>
|
||||
<summary>PyTorch implementation</summary>
|
||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;">
|
||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||
</a>
|
||||
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/transcribe_via_pytorch.ipynb">
|
||||
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb">
|
||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||
</a>
|
||||
|
||||
For an example of how to use the model in a way where you can directly stream in PyTorch tensors,
|
||||
[see our Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/transcribe_via_pytorch.ipynb).
|
||||
[see our Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb).
|
||||
|
||||
This requires the [moshi package](https://pypi.org/project/moshi/)
|
||||
with version 0.2.6 or later, which can be installed via pip.
|
||||
|
|
@ -75,9 +79,9 @@ Additionally, we provide two scripts that highlight different usage scenarios. T
|
|||
|
||||
```bash
|
||||
uv run \
|
||||
scripts/transcribe_from_file_via_pytorch.py \
|
||||
scripts/stt_from_file_pytorch.py \
|
||||
--hf-repo kyutai/stt-2.6b-en \
|
||||
--file audio/bria.mp3
|
||||
audio/bria.mp3
|
||||
```
|
||||
|
||||
The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
|
||||
|
|
@ -89,7 +93,7 @@ uv run scripts/evaluate_on_dataset.py \
|
|||
|
||||
Another example shows how one can provide a text-, audio-, or text-audio prompt to our STT model:
|
||||
```bash
|
||||
uv run scripts/transcribe_from_file_via_pytorch_with_prompt.py \
|
||||
uv run scripts/stt_from_file_pytorch_with_prompt.py \
|
||||
--hf-repo kyutai/stt-2.6b-en \
|
||||
--file bria.mp3 \
|
||||
--prompt_file ./audio/loonah.mp3 \
|
||||
|
|
@ -103,15 +107,17 @@ In the heart of an ancient forest, where the trees whispered secrets of the past
|
|||
|
||||
Apart from nudging the model for a specific spelling of a word, other potential use-cases include speaker adaptation and steering the model towards a specific formatting style or even a language.
|
||||
However, please bear in mind that is an experimental feature and its behavior is very sensitive to the prompt provided.
|
||||
</details>
|
||||
|
||||
### Rust server
|
||||
<details>
|
||||
<summary>Rust server</summary>
|
||||
|
||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
|
||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||
</a>
|
||||
|
||||
The Rust implementation provides a server that can process multiple streaming
|
||||
queries in parallel. Dependening on the amount of memory on your GPU, you may
|
||||
queries in parallel. Depending on the amount of memory on your GPU, you may
|
||||
have to adjust the batch size from the config file. For a L40S GPU, a batch size
|
||||
of 64 works well and requests can be processed at 3x real-time speed.
|
||||
|
||||
|
|
@ -135,20 +141,22 @@ moshi-server worker --config configs/config-stt-en_fr-hf.toml
|
|||
|
||||
Once the server has started you can transcribe audio from your microphone with the following script.
|
||||
```bash
|
||||
uv run scripts/transcribe_from_mic_via_rust_server.py
|
||||
uv run scripts/stt_from_mic_rust_server.py
|
||||
```
|
||||
|
||||
We also provide a script for transcribing from an audio file.
|
||||
```bash
|
||||
uv run scripts/transcribe_from_file_via_rust_server.py audio/bria.mp3
|
||||
uv run scripts/stt_from_file_rust_server.py audio/bria.mp3
|
||||
```
|
||||
|
||||
The script limits the decoding speed to simulates real-time processing of the audio.
|
||||
Faster processing can be triggered by setting
|
||||
the real-time factor, e.g. `--rtf 1000` will process
|
||||
the data as fast as possible.
|
||||
</details>
|
||||
|
||||
### Rust standalone
|
||||
<details>
|
||||
<summary>Rust standalone</summary>
|
||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
|
||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||
</a>
|
||||
|
|
@ -157,12 +165,14 @@ A standalone Rust example script is provided in the `stt-rs` directory in this r
|
|||
This can be used as follows:
|
||||
```bash
|
||||
cd stt-rs
|
||||
cargo run --features cuda -r -- audio/bria.mp3
|
||||
cargo run --features cuda -r -- ../audio/bria.mp3
|
||||
```
|
||||
You can get the timestamps by adding the `--timestamps` flag, and see the output
|
||||
of the semantic VAD by adding the `--vad` flag.
|
||||
</details>
|
||||
|
||||
### MLX implementation
|
||||
<details>
|
||||
<summary>MLX implementation</summary>
|
||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">
|
||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||
</a>
|
||||
|
|
@ -185,16 +195,119 @@ and just prefix the command above with `uvx --with moshi-mlx`.
|
|||
If you want to transcribe audio from your microphone, use:
|
||||
|
||||
```bash
|
||||
python scripts/transcribe_from_mic_via_mlx.py
|
||||
python scripts/stt_from_mic_mlx.py
|
||||
```
|
||||
|
||||
The MLX models can also be used in swift using the [moshi-swift
|
||||
codebase](https://github.com/kyutai-labs/moshi-swift), the 1b model has been
|
||||
tested to work fine on an iPhone 16 Pro.
|
||||
</details>
|
||||
|
||||
## Text-to-Speech
|
||||
## Kyutai Text-to-Speech
|
||||
|
||||
We're in the process of open-sourcing our TTS models. Check back for updates!
|
||||
<a href="https://huggingface.co/collections/kyutai/text-to-speech-6866192e7e004ed04fd39e29" target="_blank" style="margin: 2px;">
|
||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiTTS-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||
</a>
|
||||
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb">
|
||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||
</a>
|
||||
|
||||
**More details can be found on the [project page](https://kyutai.org/next/tts).**
|
||||
|
||||
We provide different implementations of Kyutai TTS for different use cases. Here is how to choose which one to use:
|
||||
|
||||
- PyTorch: for research and tinkering. If you want to call the model from Python for research or experimentation, use our PyTorch implementation.
|
||||
- Rust: for production. If you want to serve Kyutai TTS in a production setting, use our Rust server. Our robust Rust server provides streaming access to the model over websockets. We use this server to run Unmute.
|
||||
- MLX: for on-device inference on iPhone and Mac. MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon. If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
|
||||
|
||||
<details>
|
||||
<summary>PyTorch implementation</summary>
|
||||
|
||||
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb">
|
||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||
</a>
|
||||
|
||||
Check out our [Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb) or use the script:
|
||||
|
||||
```bash
|
||||
# From stdin, plays audio immediately
|
||||
echo "Hey, how are you?" | python scripts/tts_pytorch.py - -
|
||||
|
||||
# From text file to audio file
|
||||
python scripts/tts_pytorch.py text_to_say.txt audio_output.wav
|
||||
```
|
||||
|
||||
The `tts_pytorch.py` script waits for all the text to be available before
|
||||
starting the audio generation. A fully streaming implementation is available in
|
||||
the `tts_pytorch_streaming.py` script, which can be used as follows:
|
||||
|
||||
```bash
|
||||
echo "Hey, how are you?" | python scripts/tts_pytorch_streaming.py audio_output.wav
|
||||
```
|
||||
|
||||
This requires the [moshi package](https://pypi.org/project/moshi/), which can be installed via pip.
|
||||
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
|
||||
and just prefix the command above with `uvx --with moshi`.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Rust server</summary>
|
||||
|
||||
|
||||
The Rust implementation provides a server that can process multiple streaming
|
||||
queries in parallel.
|
||||
|
||||
Installing the Rust server is a bit tricky because it uses our Python implementation under the hood,
|
||||
which also requires installing the Python dependencies.
|
||||
Use the [start_tts.sh](https://github.com/kyutai-labs/unmute/blob/main/dockerless/start_tts.sh) script to properly install the Rust server.
|
||||
If you already installed the `moshi-server` crate before and it's not working, you might need to force a reinstall by running `cargo uninstall moshi-server` first.
|
||||
Feel free to open an issue if the installation is still broken.
|
||||
|
||||
Once installed, the server can be started via the following command using the config file
|
||||
from this repository.
|
||||
|
||||
```bash
|
||||
moshi-server worker --config configs/config-tts.toml
|
||||
```
|
||||
|
||||
Once the server has started you can connect to it using our script as follows:
|
||||
```bash
|
||||
# From stdin, plays audio immediately
|
||||
echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
|
||||
|
||||
# From text file to audio file
|
||||
python scripts/tts_rust_server.py text_to_say.txt audio_output.wav
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>MLX implementation</summary>
|
||||
|
||||
[MLX](https://ml-explore.github.io/mlx/build/html/index.html) is Apple's ML framework that allows you to use
|
||||
hardware acceleration on Apple silicon.
|
||||
|
||||
Use our example script to run Kyutai TTS on MLX.
|
||||
The script takes text from stdin or a file and can output to a file or stream the resulting audio.
|
||||
When streaming the output, if the model is not fast enough to keep with
|
||||
real-time, you can use the `--quantize 8` or `--quantize 4` flags to quantize
|
||||
the model resulting in faster inference.
|
||||
|
||||
```bash
|
||||
# From stdin, plays audio immediately
|
||||
echo "Hey, how are you?" | python scripts/tts_mlx.py - - --quantize 8
|
||||
|
||||
# From text file to audio file
|
||||
python scripts/tts_mlx.py text_to_say.txt audio_output.wav
|
||||
```
|
||||
|
||||
This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/), which can be installed via pip.
|
||||
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
|
||||
and just prefix the command above with `uvx --with moshi-mlx`.
|
||||
</details>
|
||||
|
||||
## FAQ
|
||||
|
||||
Checkout the [Frequently Asked Questions](FAQ.md) section before opening an issue.
|
||||
|
||||
## License
|
||||
|
||||
|
|
@ -214,4 +327,4 @@ pip install pre-commit
|
|||
pre-commit install
|
||||
```
|
||||
|
||||
If you're using `uv`, you can replace the two commands with `uvx pre-commit install`.
|
||||
If you're using `uv`, you can replace the two commands with `uvx pre-commit install`.
|
||||
|
|
|
|||
298
app/api_server.py
Normal file
298
app/api_server.py
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
OpenAI-Compatible Kyutai TTS API Server with Model Caching
|
||||
Improved version that loads the model once and keeps it in memory
|
||||
"""
|
||||
|
||||
import os
|
||||
import io
|
||||
import time
|
||||
import asyncio
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional, Literal
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
import uvicorn
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global model variables - loaded once at startup
|
||||
tts_model = None
|
||||
device = None
|
||||
sample_rate = None
|
||||
|
||||
class SpeechRequest(BaseModel):
|
||||
model: Literal["tts-1", "tts-1-hd"] = Field("tts-1", description="TTS model to use")
|
||||
input: str = Field(..., min_length=1, max_length=4096, description="Text to generate audio for")
|
||||
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field("alloy", description="Voice to use")
|
||||
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field("mp3", description="Audio format")
|
||||
speed: Optional[float] = Field(1.0, ge=0.25, le=4.0, description="Speed of generated audio")
|
||||
|
||||
app = FastAPI(
|
||||
title="OpenAI-Compatible TTS API (Cached)",
|
||||
description="OpenAI Audio Speech API compatible endpoint using Kyutai TTS with model caching",
|
||||
version="2.0.0"
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
OUTPUT_DIR = Path("/app/api_output")
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
def load_tts_model():
|
||||
"""Load TTS model once at startup and keep in memory"""
|
||||
global tts_model, device, sample_rate
|
||||
|
||||
if tts_model is not None:
|
||||
logger.info("TTS model already loaded")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("🚀 Loading Kyutai TTS model (one-time initialization)...")
|
||||
|
||||
# Import Kyutai TTS modules
|
||||
from moshi.models.loaders import CheckpointInfo
|
||||
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel
|
||||
|
||||
# Set device
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Load the TTS model
|
||||
checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
|
||||
|
||||
tts_model = TTSModel.from_checkpoint_info(
|
||||
checkpoint_info,
|
||||
n_q=32,
|
||||
temp=0.6,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Get sample rate
|
||||
sample_rate = tts_model.mimi.sample_rate
|
||||
|
||||
logger.info(f"✅ TTS model loaded successfully!")
|
||||
logger.info(f" Model: {DEFAULT_DSM_TTS_REPO}")
|
||||
logger.info(f" Device: {device}")
|
||||
logger.info(f" Sample Rate: {sample_rate}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to load TTS model: {e}")
|
||||
raise
|
||||
|
||||
def generate_audio_fast(text: str, voice: str = "alloy", speed: float = 1.0) -> bytes:
|
||||
"""Generate audio using cached TTS model"""
|
||||
global tts_model, device, sample_rate
|
||||
|
||||
if tts_model is None:
|
||||
raise HTTPException(status_code=500, detail="TTS model not loaded")
|
||||
|
||||
try:
|
||||
logger.info(f"🎵 Generating audio for: '{text[:50]}{'...' if len(text) > 50 else ''}'")
|
||||
|
||||
# Prepare the script (text input)
|
||||
entries = tts_model.prepare_script([text], padding_between=1)
|
||||
|
||||
# Voice mapping for OpenAI compatibility
|
||||
voice_mapping = {
|
||||
"alloy": "expresso/ex03-ex01_happy_001_channel1_334s.wav",
|
||||
"echo": "expresso/ex04-ex01_happy_001_channel1_334s.wav",
|
||||
"fable": "expresso/ex05-ex01_happy_001_channel1_334s.wav",
|
||||
"onyx": "expresso/ex06-ex01_happy_001_channel1_334s.wav",
|
||||
"nova": "expresso/ex07-ex01_happy_001_channel1_334s.wav",
|
||||
"shimmer": "expresso/ex08-ex01_happy_001_channel1_334s.wav"
|
||||
}
|
||||
|
||||
selected_voice = voice_mapping.get(voice, voice_mapping["alloy"])
|
||||
|
||||
try:
|
||||
voice_path = tts_model.get_voice_path(selected_voice)
|
||||
except:
|
||||
# Fallback to default if voice not found
|
||||
voice_path = tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav")
|
||||
|
||||
# Prepare condition attributes
|
||||
condition_attributes = tts_model.make_condition_attributes(
|
||||
[voice_path], cfg_coef=2.0
|
||||
)
|
||||
|
||||
# Generate audio
|
||||
pcms = []
|
||||
|
||||
def on_frame(frame):
|
||||
if (frame != -1).all():
|
||||
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
|
||||
pcms.append(torch.clamp(torch.from_numpy(pcm[0, 0]), -1, 1).numpy())
|
||||
|
||||
all_entries = [entries]
|
||||
all_condition_attributes = [condition_attributes]
|
||||
|
||||
with tts_model.mimi.streaming(len(all_entries)):
|
||||
result = tts_model.generate(all_entries, all_condition_attributes, on_frame=on_frame)
|
||||
|
||||
# Concatenate all audio frames
|
||||
if pcms:
|
||||
import numpy as np
|
||||
audio = np.concatenate(pcms, axis=-1)
|
||||
|
||||
# Apply speed adjustment if needed
|
||||
if speed != 1.0:
|
||||
# Simple speed adjustment by resampling
|
||||
from scipy import signal
|
||||
audio_length = len(audio)
|
||||
new_length = int(audio_length / speed)
|
||||
audio = signal.resample(audio, new_length)
|
||||
|
||||
# Convert to bytes
|
||||
audio_bytes = io.BytesIO()
|
||||
sf.write(audio_bytes, audio, samplerate=sample_rate, format='WAV')
|
||||
audio_bytes.seek(0)
|
||||
|
||||
logger.info(f"✅ Audio generated successfully ({len(audio)/sample_rate:.2f}s)")
|
||||
return audio_bytes.read()
|
||||
else:
|
||||
raise Exception("No audio frames generated")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ TTS generation error: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")
|
||||
|
||||
def convert_audio_format(audio_wav_bytes: bytes, output_format: str) -> bytes:
|
||||
"""Convert WAV audio to requested format using ffmpeg"""
|
||||
try:
|
||||
if output_format == "wav":
|
||||
return audio_wav_bytes
|
||||
|
||||
# Use ffmpeg to convert
|
||||
cmd = ["ffmpeg", "-f", "wav", "-i", "pipe:0", "-f", output_format, "pipe:1"]
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
input=audio_wav_bytes,
|
||||
capture_output=True,
|
||||
check=True
|
||||
)
|
||||
return result.stdout
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Audio conversion failed: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Audio conversion failed: {e}")
|
||||
|
||||
@app.post("/v1/audio/speech")
|
||||
async def create_speech(request: SpeechRequest):
|
||||
"""
|
||||
OpenAI-compatible audio speech endpoint
|
||||
Uses cached TTS model for fast generation
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Generate audio with cached model
|
||||
audio_wav_bytes = generate_audio_fast(
|
||||
text=request.input,
|
||||
voice=request.voice,
|
||||
speed=request.speed
|
||||
)
|
||||
|
||||
# Convert to requested format
|
||||
audio_data = convert_audio_format(audio_wav_bytes, request.response_format)
|
||||
|
||||
generation_time = time.time() - start_time
|
||||
logger.info(f"⚡ Total generation time: {generation_time:.2f}s")
|
||||
|
||||
# Set appropriate content type
|
||||
content_types = {
|
||||
"mp3": "audio/mpeg",
|
||||
"opus": "audio/opus",
|
||||
"aac": "audio/aac",
|
||||
"flac": "audio/flac",
|
||||
"wav": "audio/wav",
|
||||
"pcm": "audio/pcm"
|
||||
}
|
||||
|
||||
return Response(
|
||||
content=audio_data,
|
||||
media_type=content_types.get(request.response_format, "audio/wav"),
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||
"X-Generation-Time": str(generation_time)
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Speech generation failed: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models():
|
||||
"""List available models (OpenAI-compatible)"""
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "tts-1",
|
||||
"object": "model",
|
||||
"created": 1677610602,
|
||||
"owned_by": "kyutai",
|
||||
"permission": [],
|
||||
"root": "tts-1",
|
||||
"parent": None
|
||||
},
|
||||
{
|
||||
"id": "tts-1-hd",
|
||||
"object": "model",
|
||||
"created": 1677610602,
|
||||
"owned_by": "kyutai",
|
||||
"permission": [],
|
||||
"root": "tts-1-hd",
|
||||
"parent": None
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint with model status"""
|
||||
model_loaded = tts_model is not None
|
||||
return {
|
||||
"status": "healthy" if model_loaded else "loading",
|
||||
"model_loaded": model_loaded,
|
||||
"cuda_available": torch.cuda.is_available(),
|
||||
"device": str(device) if device else None,
|
||||
"service": "kyutai-tts-openai-compatible-cached"
|
||||
}
|
||||
|
||||
@app.get("/reload-model")
|
||||
async def reload_model():
|
||||
"""Reload the TTS model (admin endpoint)"""
|
||||
global tts_model
|
||||
try:
|
||||
tts_model = None
|
||||
load_tts_model()
|
||||
return {"status": "success", "message": "Model reloaded successfully"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Load model on startup"""
|
||||
logger.info("🚀 Starting TTS API server with model caching...")
|
||||
load_tts_model()
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
67
app/dependency_check.py
Normal file
67
app/dependency_check.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Check if all Kyutai TTS dependencies are properly installed
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
def check_dependencies():
|
||||
print("🔍 Checking Kyutai TTS Dependencies")
|
||||
print("=" * 40)
|
||||
|
||||
dependencies = [
|
||||
"torch",
|
||||
"numpy",
|
||||
"einops",
|
||||
"transformers",
|
||||
"accelerate",
|
||||
"soundfile",
|
||||
"librosa",
|
||||
"huggingface_hub",
|
||||
"moshi",
|
||||
"sphn"
|
||||
]
|
||||
|
||||
missing = []
|
||||
installed = []
|
||||
|
||||
for dep in dependencies:
|
||||
try:
|
||||
__import__(dep)
|
||||
installed.append(dep)
|
||||
print(f"✓ {dep}")
|
||||
except ImportError as e:
|
||||
missing.append((dep, str(e)))
|
||||
print(f"✗ {dep}: {e}")
|
||||
|
||||
print(f"\n📊 Summary:")
|
||||
print(f"✓ Installed: {len(installed)}")
|
||||
print(f"✗ Missing: {len(missing)}")
|
||||
|
||||
if missing:
|
||||
print(f"\n🔧 To fix missing dependencies:")
|
||||
for dep, error in missing:
|
||||
print(f"pip install {dep}")
|
||||
|
||||
print(f"\n🧪 Testing Kyutai TTS imports:")
|
||||
try:
|
||||
from moshi.models.loaders import CheckpointInfo
|
||||
print("✓ CheckpointInfo import successful")
|
||||
except Exception as e:
|
||||
print(f"✗ CheckpointInfo import failed: {e}")
|
||||
|
||||
try:
|
||||
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
|
||||
print("✓ TTSModel imports successful")
|
||||
except Exception as e:
|
||||
print(f"✗ TTSModel imports failed: {e}")
|
||||
|
||||
return len(missing) == 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = check_dependencies()
|
||||
if success:
|
||||
print("\n🎉 All dependencies are installed correctly!")
|
||||
else:
|
||||
print("\n❌ Some dependencies are missing. Please install them first.")
|
||||
sys.exit(1)
|
||||
59
app/scripts/tts_runner.py
Normal file
59
app/scripts/tts_runner.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Kyutai TTS PyTorch Runner
|
||||
Dockerized implementation for text-to-speech generation
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Kyutai TTS PyTorch Runner')
|
||||
parser.add_argument('input_file', help='Input text file or "-" for stdin')
|
||||
parser.add_argument('output_file', help='Output audio file')
|
||||
parser.add_argument('--model', default='kyutai/tts-1.6b-en_fr', help='TTS model to use')
|
||||
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Using device: {args.device}")
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
|
||||
# Handle stdin input
|
||||
if args.input_file == '-':
|
||||
# Read from stdin and create temporary file
|
||||
text = sys.stdin.read().strip()
|
||||
temp_file = '/tmp/temp_input.txt'
|
||||
with open(temp_file, 'w') as f:
|
||||
f.write(text)
|
||||
input_file = temp_file
|
||||
else:
|
||||
input_file = args.input_file
|
||||
|
||||
# Check if the original TTS script exists
|
||||
tts_script = Path('/app/scripts/tts_pytorch.py')
|
||||
if tts_script.exists():
|
||||
print("Using original TTS script from Kyutai repository")
|
||||
import subprocess
|
||||
cmd = ['python', str(tts_script), input_file, args.output_file]
|
||||
subprocess.run(cmd, check=True)
|
||||
else:
|
||||
print("Using moshi package for TTS generation")
|
||||
import subprocess
|
||||
cmd = [
|
||||
'python', '-m', 'moshi.run_inference',
|
||||
'--hf-repo', args.model,
|
||||
input_file,
|
||||
args.output_file
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print(f"Error: {result.stderr}")
|
||||
sys.exit(1)
|
||||
print(f"Audio generated: {args.output_file}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
EOF
|
||||
20
configs/config-tts.toml
Normal file
20
configs/config-tts.toml
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
static_dir = "./static/"
|
||||
log_dir = "$HOME/tmp/tts-logs"
|
||||
instance_name = "tts"
|
||||
authorized_ids = ["public_token"]
|
||||
|
||||
[modules.tts_py]
|
||||
type = "Py"
|
||||
path = "/api/tts_streaming"
|
||||
text_tokenizer_file = "hf://kyutai/tts-1.6b-en_fr/tokenizer_spm_8k_en_fr_audio.model"
|
||||
batch_size = 8 # Adjust to your GPU memory capacity
|
||||
text_bos_token = 1
|
||||
|
||||
[modules.tts_py.py]
|
||||
log_folder = "$HOME/tmp/moshi-server-logs"
|
||||
voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
|
||||
default_voice = "unmute-prod-website/default_voice.wav"
|
||||
cfg_coef = 2.0
|
||||
cfg_is_no_text = true
|
||||
padding_between = 1
|
||||
n_q = 24
|
||||
78
install.sh
Normal file
78
install.sh
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
# Set environment variables
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export PYTHONUNBUFFERED=1
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
# Install system dependencies
|
||||
apt-get update && apt-get install -y \
|
||||
wget \
|
||||
curl \
|
||||
git \
|
||||
build-essential \
|
||||
libsndfile1 \
|
||||
ffmpeg \
|
||||
sox \
|
||||
alsa-utils \
|
||||
pulseaudio \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
# Install Python dependencies first (for better caching)
|
||||
pip install --no-cache-dir --upgrade pip
|
||||
|
||||
# Create virtual environment
|
||||
apt install python3.12-venv python3.12-dev
|
||||
python3.12 -m venv ~/venv-tts-kyutai
|
||||
source ~/venv-tts-kyutai/bin/activate
|
||||
|
||||
# Install Python dependencies first (for better caching)
|
||||
pip install --no-cache-dir --upgrade pip
|
||||
|
||||
# Install PyTorch with CUDA support for Python 3.12
|
||||
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
||||
|
||||
# Install core dependencies
|
||||
pip install --no-cache-dir \
|
||||
numpy \
|
||||
scipy \
|
||||
librosa \
|
||||
soundfile \
|
||||
huggingface_hub \
|
||||
einops \
|
||||
transformers \
|
||||
accelerate
|
||||
|
||||
# Install API dependencies
|
||||
pip install --no-cache-dir \
|
||||
fastapi \
|
||||
uvicorn[standard] \
|
||||
python-multipart \
|
||||
pydantic
|
||||
|
||||
# Install moshi package with all dependencies (following Colab notebook)
|
||||
pip install --no-cache-dir 'sphn<0.2'
|
||||
pip install --no-cache-dir "moshi==0.2.8"
|
||||
|
||||
# Create directories for input/output
|
||||
mkdir -p /app/input /app/output /app/scripts /app/api_output
|
||||
|
||||
# Download the Kyutai delayed-streams-modeling repository
|
||||
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo
|
||||
|
||||
# Copy the TTS script from the repository
|
||||
cp /app/kyutai-repo/scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"
|
||||
|
||||
# Create directories for input/output
|
||||
mkdir -p /app/input /app/output /app/scripts /app/api_output
|
||||
|
||||
# Download the Kyutai delayed-streams-modeling repository
|
||||
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo
|
||||
|
||||
# Copy the TTS script from the repository
|
||||
cp scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"
|
||||
|
||||
# Create directories for input/output
|
||||
mkdir -p /app/input /app/output /app/scripts /app/api_output
|
||||
|
||||
# Start TTS-Server
|
||||
python /app/api_server.py
|
||||
100
scripts/stt_from_file_mlx.py
Normal file
100
scripts/stt_from_file_mlx.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
# /// script
|
||||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "huggingface_hub",
|
||||
# "moshi_mlx==0.2.12",
|
||||
# "numpy",
|
||||
# "sentencepiece",
|
||||
# "sounddevice",
|
||||
# "sphn",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import sentencepiece
|
||||
import sphn
|
||||
from huggingface_hub import hf_hub_download
|
||||
from moshi_mlx import models, utils
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("in_file", help="The file to transcribe.")
|
||||
parser.add_argument("--max-steps", default=4096)
|
||||
parser.add_argument("--hf-repo")
|
||||
parser.add_argument(
|
||||
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
audio, _ = sphn.read(args.in_file, sample_rate=24000)
|
||||
if args.hf_repo is None:
|
||||
if args.vad:
|
||||
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
|
||||
else:
|
||||
args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
|
||||
lm_config = hf_hub_download(args.hf_repo, "config.json")
|
||||
with open(lm_config, "r") as fobj:
|
||||
lm_config = json.load(fobj)
|
||||
mimi_weights = hf_hub_download(args.hf_repo, lm_config["mimi_name"])
|
||||
moshi_name = lm_config.get("moshi_name", "model.safetensors")
|
||||
moshi_weights = hf_hub_download(args.hf_repo, moshi_name)
|
||||
text_tokenizer = hf_hub_download(args.hf_repo, lm_config["tokenizer_name"])
|
||||
|
||||
lm_config = models.LmConfig.from_config_dict(lm_config)
|
||||
model = models.Lm(lm_config)
|
||||
model.set_dtype(mx.bfloat16)
|
||||
if moshi_weights.endswith(".q4.safetensors"):
|
||||
nn.quantize(model, bits=4, group_size=32)
|
||||
elif moshi_weights.endswith(".q8.safetensors"):
|
||||
nn.quantize(model, bits=8, group_size=64)
|
||||
|
||||
print(f"loading model weights from {moshi_weights}")
|
||||
if args.hf_repo.endswith("-candle"):
|
||||
model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
|
||||
else:
|
||||
model.load_weights(moshi_weights, strict=True)
|
||||
|
||||
print(f"loading the text tokenizer from {text_tokenizer}")
|
||||
text_tokenizer = sentencepiece.SentencePieceProcessor(text_tokenizer) # type: ignore
|
||||
|
||||
print(f"loading the audio tokenizer {mimi_weights}")
|
||||
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(32))
|
||||
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
|
||||
print("warming up the model")
|
||||
model.warmup()
|
||||
gen = models.LmGen(
|
||||
model=model,
|
||||
max_steps=args.max_steps,
|
||||
text_sampler=utils.Sampler(top_k=25, temp=0),
|
||||
audio_sampler=utils.Sampler(top_k=250, temp=0.8),
|
||||
check=False,
|
||||
)
|
||||
|
||||
print(f"starting inference {audio.shape}")
|
||||
audio = mx.concat([mx.array(audio), mx.zeros((1, 48000))], axis=-1)
|
||||
last_print_was_vad = False
|
||||
for start_idx in range(0, audio.shape[-1] // 1920 * 1920, 1920):
|
||||
block = audio[:, None, start_idx : start_idx + 1920]
|
||||
other_audio_tokens = audio_tokenizer.encode_step(block).transpose(0, 2, 1)
|
||||
if args.vad:
|
||||
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
|
||||
if vad_heads:
|
||||
pr_vad = vad_heads[2][0, 0, 0].item()
|
||||
if pr_vad > 0.5 and not last_print_was_vad:
|
||||
print(" [end of turn detected]")
|
||||
last_print_was_vad = True
|
||||
else:
|
||||
text_token = gen.step(other_audio_tokens[0])
|
||||
text_token = text_token[0].item()
|
||||
audio_tokens = gen.last_audio_tokens()
|
||||
_text = None
|
||||
if text_token not in (0, 3):
|
||||
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
|
||||
_text = _text.replace("▁", " ")
|
||||
print(_text, end="", flush=True)
|
||||
last_print_was_vad = False
|
||||
print()
|
||||
|
|
@ -4,7 +4,7 @@
|
|||
# "julius",
|
||||
# "librosa",
|
||||
# "soundfile",
|
||||
# "moshi",
|
||||
# "moshi==0.2.11",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
|
|
@ -20,8 +20,8 @@ import math
|
|||
import julius
|
||||
import moshi.models
|
||||
import sphn
|
||||
import time
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
@ -128,6 +128,9 @@ def tokens_to_timestamped_text(
|
|||
|
||||
|
||||
def main(args):
|
||||
if args.vad and args.hf_repo is None:
|
||||
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
|
||||
|
||||
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
|
||||
args.hf_repo,
|
||||
moshi_weights=args.moshi_weight,
|
||||
|
|
@ -150,7 +153,7 @@ def main(args):
|
|||
audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
|
||||
padding_token_id = info.raw_config.get("text_padding_token_id", 3)
|
||||
|
||||
audio, input_sample_rate = sphn.read(args.file)
|
||||
audio, input_sample_rate = sphn.read(args.in_file)
|
||||
audio = torch.from_numpy(audio).to(args.device)
|
||||
audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
|
||||
if audio.shape[-1] % mimi.frame_size != 0:
|
||||
|
|
@ -171,16 +174,35 @@ def main(args):
|
|||
itertools.repeat(silence_chunk, n_suffix_chunks),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
nchunks = 0
|
||||
last_print_was_vad = False
|
||||
with mimi.streaming(1), lm_gen.streaming(1):
|
||||
for audio_chunk in tqdm.tqdm(chunks):
|
||||
for audio_chunk in chunks:
|
||||
nchunks += 1
|
||||
audio_tokens = mimi.encode(audio_chunk)
|
||||
text_tokens = lm_gen.step(audio_tokens)
|
||||
if text_tokens is not None:
|
||||
text_tokens_accum.append(text_tokens)
|
||||
|
||||
print(tokenizer.decode(text_tokens.numpy().tolist()))
|
||||
if args.vad:
|
||||
text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens)
|
||||
if vad_heads:
|
||||
pr_vad = vad_heads[2][0, 0, 0].cpu().item()
|
||||
if pr_vad > 0.5 and not last_print_was_vad:
|
||||
print(" [end of turn detected]")
|
||||
last_print_was_vad = True
|
||||
else:
|
||||
text_tokens = lm_gen.step(audio_tokens)
|
||||
text_token = text_tokens[0, 0, 0].cpu().item()
|
||||
if text_token not in (0, 3):
|
||||
_text = tokenizer.id_to_piece(text_tokens[0, 0, 0].cpu().item()) # type: ignore
|
||||
_text = _text.replace("▁", " ")
|
||||
print(_text, end="", flush=True)
|
||||
last_print_was_vad = False
|
||||
text_tokens_accum.append(text_tokens)
|
||||
|
||||
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
|
||||
dt = time.time() - start_time
|
||||
print(
|
||||
f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}"
|
||||
)
|
||||
timed_text = tokens_to_timestamped_text(
|
||||
utterance_tokens,
|
||||
tokenizer,
|
||||
|
|
@ -211,6 +233,9 @@ if __name__ == "__main__":
|
|||
parser.add_argument(
|
||||
"--config-path", type=str, help="Path to a local config file.", default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
"""An example script that illustrates how one can prompt Kyutai STT models."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import itertools
|
||||
import math
|
||||
from collections import deque
|
||||
|
|
@ -14,15 +13,7 @@ import tqdm
|
|||
|
||||
|
||||
class PromptHook:
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
prefix,
|
||||
padding_tokens=(
|
||||
0,
|
||||
3,
|
||||
),
|
||||
):
|
||||
def __init__(self, tokenizer, prefix, padding_tokens=(0, 3)):
|
||||
self.tokenizer = tokenizer
|
||||
self.prefix_enforce = deque(self.tokenizer.encode(prefix))
|
||||
self.padding_tokens = padding_tokens
|
||||
|
|
@ -141,7 +132,7 @@ def main(args):
|
|||
prompt_frames = audio_prompt.shape[1] // mimi.frame_size
|
||||
no_prompt_offset_seconds = audio_delay_seconds + audio_silence_prefix_seconds
|
||||
no_prompt_offset = int(no_prompt_offset_seconds * mimi.frame_rate)
|
||||
text_tokens = text_tokens[prompt_frames + no_prompt_offset:]
|
||||
text_tokens = text_tokens[prompt_frames + no_prompt_offset :]
|
||||
|
||||
text = tokenizer.decode(
|
||||
text_tokens[text_tokens > padding_token_id].numpy().tolist()
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "huggingface_hub",
|
||||
# "moshi_mlx",
|
||||
# "moshi_mlx==0.2.12",
|
||||
# "numpy",
|
||||
# "rustymimi",
|
||||
# "sentencepiece",
|
||||
|
|
@ -25,9 +25,17 @@ from moshi_mlx import models, utils
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--max-steps", default=4096)
|
||||
parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
|
||||
parser.add_argument("--hf-repo")
|
||||
parser.add_argument(
|
||||
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.hf_repo is None:
|
||||
if args.vad:
|
||||
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
|
||||
else:
|
||||
args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
|
||||
lm_config = hf_hub_download(args.hf_repo, "config.json")
|
||||
with open(lm_config, "r") as fobj:
|
||||
lm_config = json.load(fobj)
|
||||
|
|
@ -45,7 +53,10 @@ if __name__ == "__main__":
|
|||
nn.quantize(model, bits=8, group_size=64)
|
||||
|
||||
print(f"loading model weights from {moshi_weights}")
|
||||
model.load_weights(moshi_weights, strict=True)
|
||||
if args.hf_repo.endswith("-candle"):
|
||||
model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
|
||||
else:
|
||||
model.load_weights(moshi_weights, strict=True)
|
||||
|
||||
print(f"loading the text tokenizer from {tokenizer}")
|
||||
text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer) # type: ignore
|
||||
|
|
@ -71,6 +82,7 @@ if __name__ == "__main__":
|
|||
block_queue.put(indata.copy())
|
||||
|
||||
print("recording audio from microphone, speak to get your words transcribed")
|
||||
last_print_was_vad = False
|
||||
with sd.InputStream(
|
||||
channels=1,
|
||||
dtype="float32",
|
||||
|
|
@ -85,7 +97,15 @@ if __name__ == "__main__":
|
|||
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
|
||||
:, :, :other_codebooks
|
||||
]
|
||||
text_token = gen.step(other_audio_tokens[0])
|
||||
if args.vad:
|
||||
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
|
||||
if vad_heads:
|
||||
pr_vad = vad_heads[2][0, 0, 0].item()
|
||||
if pr_vad > 0.5 and not last_print_was_vad:
|
||||
print(" [end of turn detected]")
|
||||
last_print_was_vad = True
|
||||
else:
|
||||
text_token = gen.step(other_audio_tokens[0])
|
||||
text_token = text_token[0].item()
|
||||
audio_tokens = gen.last_audio_tokens()
|
||||
_text = None
|
||||
|
|
@ -93,3 +113,4 @@ if __name__ == "__main__":
|
|||
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
|
||||
_text = _text.replace("▁", " ")
|
||||
print(_text, end="", flush=True)
|
||||
last_print_was_vad = False
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "huggingface_hub",
|
||||
# "moshi_mlx>=0.2.8",
|
||||
# "moshi_mlx==0.2.12",
|
||||
# "numpy",
|
||||
# "sounddevice",
|
||||
# ]
|
||||
|
|
@ -10,23 +10,24 @@
|
|||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
import queue
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import sentencepiece
|
||||
import sphn
|
||||
import time
|
||||
|
||||
import sounddevice as sd
|
||||
|
||||
from moshi_mlx.client_utils import make_log
|
||||
import sphn
|
||||
from moshi_mlx import models
|
||||
from moshi_mlx.client_utils import make_log
|
||||
from moshi_mlx.models.tts import (
|
||||
DEFAULT_DSM_TTS_REPO,
|
||||
DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
TTSModel,
|
||||
)
|
||||
from moshi_mlx.utils.loaders import hf_get
|
||||
from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
|
||||
|
||||
|
||||
def log(level: str, msg: str):
|
||||
|
|
@ -34,15 +35,32 @@ def log(level: str, msg: str):
|
|||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Kyutai TTS using the MLX implementation"
|
||||
)
|
||||
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
|
||||
parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
|
||||
parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
|
||||
help="HF repo in which to look for the pretrained models.")
|
||||
parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
help="HF repo in which to look for pre-computed voice embeddings.")
|
||||
parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
|
||||
parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
|
||||
parser.add_argument(
|
||||
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-repo",
|
||||
type=str,
|
||||
default=DEFAULT_DSM_TTS_REPO,
|
||||
help="HF repo in which to look for the pretrained models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice-repo",
|
||||
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
type=int,
|
||||
help="The quantization to be applied, e.g. 8 for 8 bits.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(299792458)
|
||||
|
|
@ -58,6 +76,9 @@ def main():
|
|||
moshi_weights = hf_get(moshi_name, args.hf_repo)
|
||||
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
|
||||
lm_config = models.LmConfig.from_config_dict(raw_config)
|
||||
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
|
||||
# The following line gets around it for now.
|
||||
lm_config.transformer.max_seq_len = lm_config.transformer.context
|
||||
model = models.Lm(lm_config)
|
||||
model.set_dtype(mx.bfloat16)
|
||||
|
||||
|
|
@ -96,7 +117,7 @@ def main():
|
|||
if tts_model.valid_cfg_conditionings:
|
||||
# Model was trained with CFG distillation.
|
||||
cfg_coef_conditioning = tts_model.cfg_coef
|
||||
tts_model.cfg_coef = 1.
|
||||
tts_model.cfg_coef = 1.0
|
||||
cfg_is_no_text = False
|
||||
cfg_is_no_prefix = False
|
||||
else:
|
||||
|
|
@ -105,21 +126,29 @@ def main():
|
|||
mimi = tts_model.mimi
|
||||
|
||||
log("info", f"reading input from {args.inp}")
|
||||
with open(args.inp, "r") as fobj:
|
||||
text_to_tts = fobj.read().strip()
|
||||
if args.inp == "-":
|
||||
if sys.stdin.isatty(): # Interactive
|
||||
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||
text_to_tts = sys.stdin.read().strip()
|
||||
else:
|
||||
with open(args.inp, "r") as fobj:
|
||||
text_to_tts = fobj.read().strip()
|
||||
|
||||
all_entries = [tts_model.prepare_script([text_to_tts])]
|
||||
if tts_model.multi_speaker:
|
||||
voices = [tts_model.get_voice_path(args.voice)]
|
||||
else:
|
||||
voices = []
|
||||
all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
|
||||
all_attributes = [
|
||||
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
|
||||
]
|
||||
|
||||
wav_frames = queue.Queue()
|
||||
def _on_audio_hook(audio_tokens):
|
||||
if (audio_tokens == -1).any():
|
||||
|
||||
def _on_frame(frame):
|
||||
if (frame == -1).any():
|
||||
return
|
||||
_pcm = tts_model.mimi.decode_step(audio_tokens[None, :, None])
|
||||
_pcm = tts_model.mimi.decode_step(frame[:, :, None])
|
||||
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
|
||||
wav_frames.put_nowait(_pcm)
|
||||
|
||||
|
|
@ -131,7 +160,7 @@ def main():
|
|||
all_attributes,
|
||||
cfg_is_no_prefix=cfg_is_no_prefix,
|
||||
cfg_is_no_text=cfg_is_no_text,
|
||||
on_audio_hook=_on_audio_hook,
|
||||
on_frame=_on_frame,
|
||||
)
|
||||
frames = mx.concat(result.frames, axis=-1)
|
||||
total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
|
||||
|
|
@ -141,16 +170,20 @@ def main():
|
|||
return result
|
||||
|
||||
if args.out == "-":
|
||||
|
||||
def audio_callback(outdata, _a, _b, _c):
|
||||
try:
|
||||
pcm_data = wav_frames.get(block=False)
|
||||
outdata[:, 0] = pcm_data
|
||||
except queue.Empty:
|
||||
outdata[:] = 0
|
||||
with sd.OutputStream(samplerate=mimi.sample_rate,
|
||||
blocksize=1920,
|
||||
channels=1,
|
||||
callback=audio_callback):
|
||||
|
||||
with sd.OutputStream(
|
||||
samplerate=mimi.sample_rate,
|
||||
blocksize=1920,
|
||||
channels=1,
|
||||
callback=audio_callback,
|
||||
):
|
||||
run()
|
||||
time.sleep(3)
|
||||
while True:
|
||||
|
|
@ -158,6 +191,7 @@ def main():
|
|||
break
|
||||
time.sleep(1)
|
||||
else:
|
||||
run()
|
||||
frames = []
|
||||
while True:
|
||||
try:
|
||||
|
|
|
|||
317
scripts/tts_mlx_streaming.py
Normal file
317
scripts/tts_mlx_streaming.py
Normal file
|
|
@ -0,0 +1,317 @@
|
|||
# /// script
|
||||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "huggingface_hub",
|
||||
# "moshi_mlx==0.2.12",
|
||||
# "numpy",
|
||||
# "sounddevice",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import queue
|
||||
import sys
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import sentencepiece
|
||||
import sounddevice as sd
|
||||
import sphn
|
||||
import typing as tp
|
||||
from moshi_mlx import models
|
||||
from moshi_mlx.models.generate import LmGen
|
||||
from moshi_mlx.client_utils import make_log
|
||||
from moshi_mlx.modules.conditioner import (
|
||||
ConditionAttributes,
|
||||
ConditionTensor,
|
||||
dropout_all_conditions,
|
||||
)
|
||||
from moshi_mlx.utils.sampling import Sampler
|
||||
from moshi_mlx.models.tts import (
|
||||
Entry,
|
||||
DEFAULT_DSM_TTS_REPO,
|
||||
DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
TTSModel,
|
||||
script_to_entries,
|
||||
)
|
||||
from moshi_mlx.utils.loaders import hf_get
|
||||
|
||||
|
||||
def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]:
|
||||
multi_speaker = first_turn and model.multi_speaker
|
||||
return script_to_entries(
|
||||
model.tokenizer,
|
||||
model.machine.token_ids,
|
||||
model.mimi.frame_rate,
|
||||
[script],
|
||||
multi_speaker=multi_speaker,
|
||||
padding_between=1,
|
||||
)
|
||||
|
||||
|
||||
def _make_null(
|
||||
all_attributes: tp.Sequence[ConditionAttributes],
|
||||
) -> list[ConditionAttributes]:
|
||||
# When using CFG, returns the null conditions.
|
||||
return dropout_all_conditions(all_attributes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSGen:
|
||||
tts_model: TTSModel
|
||||
attributes: tp.Sequence[ConditionAttributes]
|
||||
on_frame: tp.Optional[tp.Callable[[mx.array], None]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
tts_model = self.tts_model
|
||||
attributes = self.attributes
|
||||
self.offset = 0
|
||||
self.state = self.tts_model.machine.new_state([])
|
||||
|
||||
if tts_model.cfg_coef != 1.0:
|
||||
if tts_model.valid_cfg_conditionings:
|
||||
raise ValueError(
|
||||
"This model does not support direct CFG, but was trained with "
|
||||
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
|
||||
)
|
||||
nulled = _make_null(attributes)
|
||||
attributes = list(attributes) + nulled
|
||||
|
||||
assert tts_model.lm.condition_provider is not None
|
||||
self.ct = None
|
||||
self.cross_attention_src = None
|
||||
for _attr in attributes:
|
||||
for _key, _value in _attr.text.items():
|
||||
_ct = tts_model.lm.condition_provider.condition_tensor(_key, _value)
|
||||
if self.ct is None:
|
||||
self.ct = _ct
|
||||
else:
|
||||
self.ct = ConditionTensor(self.ct.tensor + _ct.tensor)
|
||||
for _key, _value in _attr.tensor.items():
|
||||
_conditioner = tts_model.lm.condition_provider.conditioners[_key]
|
||||
_ca_src = _conditioner.condition(_value)
|
||||
if self.cross_attention_src is None:
|
||||
self.cross_attention_src = _ca_src
|
||||
else:
|
||||
raise ValueError("multiple cross-attention conditioners")
|
||||
|
||||
def _on_audio_hook(audio_tokens):
|
||||
delays = tts_model.lm.delays
|
||||
for q in range(audio_tokens.shape[0]):
|
||||
delay = delays[q]
|
||||
if self.offset < delay + tts_model.delay_steps:
|
||||
audio_tokens[q] = tts_model.machine.token_ids.zero
|
||||
|
||||
def _on_text_hook(text_tokens):
|
||||
tokens = text_tokens.tolist()
|
||||
out_tokens = []
|
||||
for token in tokens:
|
||||
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
|
||||
out_tokens.append(out_token)
|
||||
text_tokens[:] = mx.array(out_tokens, dtype=mx.int64)
|
||||
|
||||
self.lm_gen = LmGen(
|
||||
tts_model.lm,
|
||||
max_steps=tts_model.max_gen_length,
|
||||
text_sampler=Sampler(temp=tts_model.temp),
|
||||
audio_sampler=Sampler(temp=tts_model.temp),
|
||||
cfg_coef=tts_model.cfg_coef,
|
||||
on_text_hook=_on_text_hook,
|
||||
on_audio_hook=_on_audio_hook,
|
||||
# TODO(laurent):
|
||||
# cfg_is_masked_until=cfg_is_masked_until,
|
||||
# cfg_is_no_text=cfg_is_no_text,
|
||||
)
|
||||
|
||||
def process_last(self):
|
||||
while len(self.state.entries) > 0 or self.state.end_step is not None:
|
||||
self._step()
|
||||
additional_steps = (
|
||||
self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
|
||||
)
|
||||
for _ in range(additional_steps):
|
||||
self._step()
|
||||
|
||||
def process(self):
|
||||
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
|
||||
self._step()
|
||||
|
||||
def _step(self):
|
||||
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
||||
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
||||
input_tokens = (
|
||||
mx.ones((1, missing), dtype=mx.int64)
|
||||
* self.tts_model.machine.token_ids.zero
|
||||
)
|
||||
self.lm_gen.step(
|
||||
input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src
|
||||
)
|
||||
frame = self.lm_gen.last_audio_tokens()
|
||||
self.offset += 1
|
||||
if frame is not None:
|
||||
if self.on_frame is not None:
|
||||
self.on_frame(frame)
|
||||
|
||||
def append_entry(self, entry):
|
||||
self.state.entries.append(entry)
|
||||
|
||||
|
||||
def log(level: str, msg: str):
|
||||
print(make_log(level, msg))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Kyutai TTS using the MLX implementation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-repo",
|
||||
type=str,
|
||||
default=DEFAULT_DSM_TTS_REPO,
|
||||
help="HF repo in which to look for the pretrained models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice-repo",
|
||||
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
type=int,
|
||||
help="The quantization to be applied, e.g. 8 for 8 bits.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(299792458)
|
||||
|
||||
log("info", "retrieving checkpoints")
|
||||
|
||||
raw_config = hf_get("config.json", args.hf_repo)
|
||||
with open(hf_get(raw_config), "r") as fobj:
|
||||
raw_config = json.load(fobj)
|
||||
|
||||
mimi_weights = hf_get(raw_config["mimi_name"], args.hf_repo)
|
||||
moshi_name = raw_config.get("moshi_name", "model.safetensors")
|
||||
moshi_weights = hf_get(moshi_name, args.hf_repo)
|
||||
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
|
||||
lm_config = models.LmConfig.from_config_dict(raw_config)
|
||||
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
|
||||
# The following line gets around it for now.
|
||||
lm_config.transformer.max_seq_len = lm_config.transformer.context
|
||||
model = models.Lm(lm_config)
|
||||
model.set_dtype(mx.bfloat16)
|
||||
|
||||
log("info", f"loading model weights from {moshi_weights}")
|
||||
model.load_pytorch_weights(str(moshi_weights), lm_config, strict=True)
|
||||
|
||||
if args.quantize is not None:
|
||||
log("info", f"quantizing model to {args.quantize} bits")
|
||||
nn.quantize(model.depformer, bits=args.quantize)
|
||||
for layer in model.transformer.layers:
|
||||
nn.quantize(layer.self_attn, bits=args.quantize)
|
||||
nn.quantize(layer.gating, bits=args.quantize)
|
||||
|
||||
log("info", f"loading the text tokenizer from {tokenizer}")
|
||||
text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer)) # type: ignore
|
||||
|
||||
log("info", f"loading the audio tokenizer {mimi_weights}")
|
||||
generated_codebooks = lm_config.generated_codebooks
|
||||
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(generated_codebooks))
|
||||
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
|
||||
|
||||
cfg_coef_conditioning = None
|
||||
tts_model = TTSModel(
|
||||
model,
|
||||
audio_tokenizer,
|
||||
text_tokenizer,
|
||||
voice_repo=args.voice_repo,
|
||||
temp=0.6,
|
||||
cfg_coef=1,
|
||||
max_padding=8,
|
||||
initial_padding=2,
|
||||
final_padding=2,
|
||||
padding_bonus=0,
|
||||
raw_config=raw_config,
|
||||
)
|
||||
if tts_model.valid_cfg_conditionings:
|
||||
# Model was trained with CFG distillation.
|
||||
cfg_coef_conditioning = tts_model.cfg_coef
|
||||
tts_model.cfg_coef = 1.0
|
||||
mimi = tts_model.mimi
|
||||
|
||||
log("info", "reading input from stdin")
|
||||
|
||||
if tts_model.multi_speaker:
|
||||
voices = [tts_model.get_voice_path(args.voice)]
|
||||
else:
|
||||
voices = []
|
||||
all_attributes = [
|
||||
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
|
||||
]
|
||||
|
||||
wav_frames = queue.Queue()
|
||||
|
||||
def _on_frame(frame):
|
||||
if (frame == -1).any():
|
||||
return
|
||||
_pcm = tts_model.mimi.decode_step(frame[:, :, None])
|
||||
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
|
||||
wav_frames.put_nowait(_pcm)
|
||||
|
||||
gen = TTSGen(tts_model, all_attributes, on_frame=_on_frame)
|
||||
|
||||
def run():
|
||||
log("info", "starting the inference loop")
|
||||
first_turn = True
|
||||
for line in sys.stdin:
|
||||
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
|
||||
first_turn = False
|
||||
for entry in entries:
|
||||
gen.append_entry(entry)
|
||||
gen.process()
|
||||
gen.process_last()
|
||||
|
||||
if args.out == "-":
|
||||
|
||||
def audio_callback(outdata, _a, _b, _c):
|
||||
try:
|
||||
pcm_data = wav_frames.get(block=False)
|
||||
outdata[:, 0] = pcm_data
|
||||
except queue.Empty:
|
||||
outdata[:] = 0
|
||||
|
||||
with sd.OutputStream(
|
||||
samplerate=mimi.sample_rate,
|
||||
blocksize=1920,
|
||||
channels=1,
|
||||
callback=audio_callback,
|
||||
):
|
||||
run()
|
||||
while True:
|
||||
if wav_frames.qsize() == 0:
|
||||
break
|
||||
time.sleep(1)
|
||||
else:
|
||||
run()
|
||||
frames = []
|
||||
while True:
|
||||
try:
|
||||
frames.append(wav_frames.get_nowait())
|
||||
except queue.Empty:
|
||||
break
|
||||
wav = np.concat(frames, -1)
|
||||
sphn.write_wav(args.out, wav, mimi.sample_rate)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
140
scripts/tts_pytorch.py
Normal file
140
scripts/tts_pytorch.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
# /// script
|
||||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "moshi==0.2.11",
|
||||
# "torch",
|
||||
# "sphn",
|
||||
# "sounddevice",
|
||||
# ]
|
||||
# ///
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import queue
|
||||
import sphn
|
||||
import time
|
||||
import torch
|
||||
from moshi.models.loaders import CheckpointInfo
|
||||
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Kyutai TTS using the PyTorch implementation"
|
||||
)
|
||||
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
|
||||
parser.add_argument(
|
||||
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-repo",
|
||||
type=str,
|
||||
default=DEFAULT_DSM_TTS_REPO,
|
||||
help="HF repo in which to look for the pretrained models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice-repo",
|
||||
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice",
|
||||
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
|
||||
help="The voice to use, relative to the voice repo root. "
|
||||
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="Device on which to run, defaults to 'cuda'.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Loading model...")
|
||||
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
|
||||
tts_model = TTSModel.from_checkpoint_info(
|
||||
checkpoint_info, n_q=32, temp=0.6, device=args.device
|
||||
)
|
||||
|
||||
if args.inp == "-":
|
||||
if sys.stdin.isatty(): # Interactive
|
||||
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||
text = sys.stdin.read().strip()
|
||||
else:
|
||||
with open(args.inp, "r") as fobj:
|
||||
text = fobj.read().strip()
|
||||
|
||||
# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]
|
||||
entries = tts_model.prepare_script([text], padding_between=1)
|
||||
if args.voice.endswith(".safetensors"):
|
||||
voice_path = args.voice
|
||||
else:
|
||||
voice_path = tts_model.get_voice_path(args.voice)
|
||||
# CFG coef goes here because the model was trained with CFG distillation,
|
||||
# so it's not _actually_ doing CFG at inference time.
|
||||
# Also, if you are generating a dialog, you should have two voices in the list.
|
||||
condition_attributes = tts_model.make_condition_attributes(
|
||||
[voice_path], cfg_coef=2.0
|
||||
)
|
||||
_frames_cnt = 0
|
||||
|
||||
if args.out == "-":
|
||||
# Stream the audio to the speakers using sounddevice.
|
||||
import sounddevice as sd
|
||||
|
||||
pcms = queue.Queue()
|
||||
|
||||
def _on_frame(frame):
|
||||
nonlocal _frames_cnt
|
||||
if (frame != -1).all():
|
||||
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
|
||||
pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
|
||||
_frames_cnt += 1
|
||||
print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
|
||||
|
||||
def audio_callback(outdata, _a, _b, _c):
|
||||
try:
|
||||
pcm_data = pcms.get(block=False)
|
||||
outdata[:, 0] = pcm_data
|
||||
except queue.Empty:
|
||||
outdata[:] = 0
|
||||
|
||||
with sd.OutputStream(
|
||||
samplerate=tts_model.mimi.sample_rate,
|
||||
blocksize=1920,
|
||||
channels=1,
|
||||
callback=audio_callback,
|
||||
):
|
||||
with tts_model.mimi.streaming(1):
|
||||
tts_model.generate(
|
||||
[entries], [condition_attributes], on_frame=_on_frame
|
||||
)
|
||||
time.sleep(3)
|
||||
while True:
|
||||
if pcms.qsize() == 0:
|
||||
break
|
||||
time.sleep(1)
|
||||
else:
|
||||
|
||||
def _on_frame(frame):
|
||||
nonlocal _frames_cnt
|
||||
if (frame != -1).all():
|
||||
_frames_cnt += 1
|
||||
print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
|
||||
|
||||
result = tts_model.generate(
|
||||
[entries], [condition_attributes], on_frame=_on_frame
|
||||
)
|
||||
with tts_model.mimi.streaming(1), torch.no_grad():
|
||||
pcms = []
|
||||
for frame in result.frames[tts_model.delay_steps :]:
|
||||
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
|
||||
pcms.append(np.clip(pcm[0, 0], -1, 1))
|
||||
pcm = np.concatenate(pcms, axis=-1)
|
||||
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
261
scripts/tts_pytorch_streaming.py
Normal file
261
scripts/tts_pytorch_streaming.py
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
# /// script
|
||||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "moshi==0.2.11",
|
||||
# "torch",
|
||||
# "sphn",
|
||||
# "sounddevice",
|
||||
# ]
|
||||
# ///
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import queue
|
||||
import sphn
|
||||
import time
|
||||
import torch
|
||||
import typing as tp
|
||||
from moshi.models.loaders import CheckpointInfo
|
||||
from moshi.conditioners import dropout_all_conditions
|
||||
from moshi.models.lm import LMGen
|
||||
from moshi.models.tts import (
|
||||
Entry,
|
||||
DEFAULT_DSM_TTS_REPO,
|
||||
DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
TTSModel,
|
||||
ConditionAttributes,
|
||||
script_to_entries,
|
||||
)
|
||||
|
||||
|
||||
def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]:
|
||||
multi_speaker = first_turn and model.multi_speaker
|
||||
return script_to_entries(
|
||||
model.tokenizer,
|
||||
model.machine.token_ids,
|
||||
model.mimi.frame_rate,
|
||||
[script],
|
||||
multi_speaker=multi_speaker,
|
||||
padding_between=1,
|
||||
)
|
||||
|
||||
|
||||
def _make_null(
|
||||
all_attributes: tp.Sequence[ConditionAttributes],
|
||||
) -> list[ConditionAttributes]:
|
||||
# When using CFG, returns the null conditions.
|
||||
return dropout_all_conditions(all_attributes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSGen:
|
||||
tts_model: TTSModel
|
||||
attributes: tp.Sequence[ConditionAttributes]
|
||||
on_frame: tp.Optional[tp.Callable[[torch.Tensor], None]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
tts_model = self.tts_model
|
||||
attributes = self.attributes
|
||||
self.offset = 0
|
||||
self.state = self.tts_model.machine.new_state([])
|
||||
if tts_model.cfg_coef != 1.0:
|
||||
if tts_model.valid_cfg_conditionings:
|
||||
raise ValueError(
|
||||
"This model does not support direct CFG, but was trained with "
|
||||
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
|
||||
)
|
||||
nulled = _make_null(attributes)
|
||||
attributes = list(attributes) + nulled
|
||||
|
||||
assert tts_model.lm.condition_provider is not None
|
||||
prepared = tts_model.lm.condition_provider.prepare(attributes)
|
||||
condition_tensors = tts_model.lm.condition_provider(prepared)
|
||||
|
||||
def _on_text_logits_hook(text_logits):
|
||||
if tts_model.padding_bonus:
|
||||
text_logits[..., tts_model.machine.token_ids.pad] += (
|
||||
tts_model.padding_bonus
|
||||
)
|
||||
return text_logits
|
||||
|
||||
def _on_audio_hook(audio_tokens):
|
||||
audio_offset = tts_model.lm.audio_offset
|
||||
delays = tts_model.lm.delays
|
||||
for q in range(audio_tokens.shape[1]):
|
||||
delay = delays[q + audio_offset]
|
||||
if self.offset < delay + tts_model.delay_steps:
|
||||
audio_tokens[:, q] = tts_model.machine.token_ids.zero
|
||||
|
||||
def _on_text_hook(text_tokens):
|
||||
tokens = text_tokens.tolist()
|
||||
out_tokens = []
|
||||
for token in tokens:
|
||||
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
|
||||
out_tokens.append(out_token)
|
||||
text_tokens[:] = torch.tensor(
|
||||
out_tokens, dtype=torch.long, device=text_tokens.device
|
||||
)
|
||||
|
||||
tts_model.lm.dep_q = tts_model.n_q
|
||||
self.lm_gen = LMGen(
|
||||
tts_model.lm,
|
||||
temp=tts_model.temp,
|
||||
temp_text=tts_model.temp,
|
||||
cfg_coef=tts_model.cfg_coef,
|
||||
condition_tensors=condition_tensors,
|
||||
on_text_logits_hook=_on_text_logits_hook,
|
||||
on_text_hook=_on_text_hook,
|
||||
on_audio_hook=_on_audio_hook,
|
||||
cfg_is_masked_until=None,
|
||||
cfg_is_no_text=True,
|
||||
)
|
||||
self.lm_gen.streaming_forever(1)
|
||||
|
||||
def process_last(self):
|
||||
while len(self.state.entries) > 0 or self.state.end_step is not None:
|
||||
self._step()
|
||||
additional_steps = (
|
||||
self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
|
||||
)
|
||||
for _ in range(additional_steps):
|
||||
self._step()
|
||||
|
||||
def process(self):
|
||||
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
|
||||
self._step()
|
||||
|
||||
def _step(self):
|
||||
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
||||
input_tokens = torch.full(
|
||||
(1, missing, 1),
|
||||
self.tts_model.machine.token_ids.zero,
|
||||
dtype=torch.long,
|
||||
device=self.tts_model.lm.device,
|
||||
)
|
||||
frame = self.lm_gen.step(input_tokens)
|
||||
self.offset += 1
|
||||
if frame is not None:
|
||||
if self.on_frame is not None:
|
||||
self.on_frame(frame)
|
||||
|
||||
def append_entry(self, entry):
|
||||
self.state.entries.append(entry)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Kyutai TTS using the PyTorch implementation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-repo",
|
||||
type=str,
|
||||
default=DEFAULT_DSM_TTS_REPO,
|
||||
help="HF repo in which to look for the pretrained models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice-repo",
|
||||
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice",
|
||||
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
|
||||
help="The voice to use, relative to the voice repo root. "
|
||||
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="Device on which to run, defaults to 'cuda'.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Loading model...")
|
||||
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
|
||||
tts_model = TTSModel.from_checkpoint_info(
|
||||
checkpoint_info, n_q=32, temp=0.6, device=args.device
|
||||
)
|
||||
|
||||
if args.voice.endswith(".safetensors"):
|
||||
voice_path = args.voice
|
||||
else:
|
||||
voice_path = tts_model.get_voice_path(args.voice)
|
||||
# CFG coef goes here because the model was trained with CFG distillation,
|
||||
# so it's not _actually_ doing CFG at inference time.
|
||||
# Also, if you are generating a dialog, you should have two voices in the list.
|
||||
condition_attributes = tts_model.make_condition_attributes(
|
||||
[voice_path], cfg_coef=2.0
|
||||
)
|
||||
|
||||
if sys.stdin.isatty(): # Interactive
|
||||
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||
|
||||
if args.out == "-":
|
||||
# Stream the audio to the speakers using sounddevice.
|
||||
import sounddevice as sd
|
||||
|
||||
pcms = queue.Queue()
|
||||
|
||||
def _on_frame(frame):
|
||||
if (frame != -1).all():
|
||||
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
|
||||
pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
|
||||
|
||||
def audio_callback(outdata, _a, _b, _c):
|
||||
try:
|
||||
pcm_data = pcms.get(block=False)
|
||||
outdata[:, 0] = pcm_data
|
||||
except queue.Empty:
|
||||
outdata[:] = 0
|
||||
|
||||
gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame)
|
||||
|
||||
with sd.OutputStream(
|
||||
samplerate=tts_model.mimi.sample_rate,
|
||||
blocksize=1920,
|
||||
channels=1,
|
||||
callback=audio_callback,
|
||||
) and tts_model.mimi.streaming(1):
|
||||
first_turn = True
|
||||
for line in sys.stdin:
|
||||
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
|
||||
first_turn = False
|
||||
for entry in entries:
|
||||
gen.append_entry(entry)
|
||||
gen.process()
|
||||
gen.process_last()
|
||||
while True:
|
||||
if pcms.qsize() == 0:
|
||||
break
|
||||
time.sleep(1)
|
||||
else:
|
||||
pcms = []
|
||||
|
||||
def _on_frame(frame: torch.Tensor):
|
||||
if (frame != -1).all():
|
||||
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
|
||||
pcms.append(np.clip(pcm[0, 0]))
|
||||
|
||||
gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame)
|
||||
with tts_model.mimi.streaming(1):
|
||||
first_turn = True
|
||||
for line in sys.stdin:
|
||||
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
|
||||
first_turn = False
|
||||
for entry in entries:
|
||||
gen.append_entry(entry)
|
||||
gen.process()
|
||||
gen.process_last()
|
||||
pcm = np.concatenate(pcms, axis=-1)
|
||||
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
178
scripts/tts_rust_server.py
Normal file
178
scripts/tts_rust_server.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
# /// script
|
||||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "msgpack",
|
||||
# "numpy",
|
||||
# "sphn",
|
||||
# "websockets",
|
||||
# "sounddevice",
|
||||
# "tqdm",
|
||||
# ]
|
||||
# ///
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
import sphn
|
||||
import tqdm
|
||||
import websockets
|
||||
|
||||
SAMPLE_RATE = 24000
|
||||
|
||||
TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."
|
||||
DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
|
||||
AUTH_TOKEN = "public_token"
|
||||
|
||||
|
||||
async def receive_messages(websocket: websockets.ClientConnection, output_queue):
|
||||
with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar:
|
||||
accumulated_samples = 0
|
||||
last_seconds = 0
|
||||
|
||||
async for message_bytes in websocket:
|
||||
msg = msgpack.unpackb(message_bytes)
|
||||
|
||||
if msg["type"] == "Audio":
|
||||
pcm = np.array(msg["pcm"]).astype(np.float32)
|
||||
await output_queue.put(pcm)
|
||||
|
||||
accumulated_samples += len(msg["pcm"])
|
||||
current_seconds = accumulated_samples // SAMPLE_RATE
|
||||
if current_seconds > last_seconds:
|
||||
pbar.update(current_seconds - last_seconds)
|
||||
last_seconds = current_seconds
|
||||
|
||||
print("End of audio.")
|
||||
await output_queue.put(None) # Signal end of audio
|
||||
|
||||
|
||||
async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
|
||||
if out == "-":
|
||||
should_exit = False
|
||||
|
||||
def audio_callback(outdata, _a, _b, _c):
|
||||
nonlocal should_exit
|
||||
|
||||
try:
|
||||
pcm_data = output_queue.get_nowait()
|
||||
if pcm_data is not None:
|
||||
outdata[:, 0] = pcm_data
|
||||
else:
|
||||
should_exit = True
|
||||
outdata[:] = 0
|
||||
except asyncio.QueueEmpty:
|
||||
outdata[:] = 0
|
||||
|
||||
with sd.OutputStream(
|
||||
samplerate=SAMPLE_RATE,
|
||||
blocksize=1920,
|
||||
channels=1,
|
||||
callback=audio_callback,
|
||||
):
|
||||
while True:
|
||||
if should_exit:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
frames = []
|
||||
while True:
|
||||
item = await output_queue.get()
|
||||
if item is None:
|
||||
break
|
||||
frames.append(item)
|
||||
|
||||
sphn.write_wav(out, np.concat(frames, -1), SAMPLE_RATE)
|
||||
print(f"Saved audio to {out}")
|
||||
|
||||
|
||||
async def read_lines_from_stdin():
|
||||
reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(reader)
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
|
||||
while True:
|
||||
line = await reader.readline()
|
||||
if not line:
|
||||
break
|
||||
yield line.decode().rstrip()
|
||||
|
||||
|
||||
async def read_lines_from_file(path: str):
|
||||
queue = asyncio.Queue()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def producer():
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
asyncio.run_coroutine_threadsafe(queue.put(line), loop)
|
||||
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
|
||||
|
||||
await asyncio.to_thread(producer)
|
||||
while True:
|
||||
line = await queue.get()
|
||||
if line is None:
|
||||
break
|
||||
yield line
|
||||
|
||||
|
||||
async def get_lines(source: str):
|
||||
if source == "-":
|
||||
async for line in read_lines_from_stdin():
|
||||
yield line
|
||||
else:
|
||||
async for line in read_lines_from_file(source):
|
||||
yield line
|
||||
|
||||
|
||||
async def websocket_client():
|
||||
parser = argparse.ArgumentParser(description="Use the TTS streaming API")
|
||||
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
|
||||
parser.add_argument(
|
||||
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--voice",
|
||||
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
|
||||
help="The voice to use, relative to the voice repo root. "
|
||||
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
help="The URL of the server to which to send the audio",
|
||||
default="ws://127.0.0.1:8080",
|
||||
)
|
||||
parser.add_argument("--api-key", default="public_token")
|
||||
args = parser.parse_args()
|
||||
|
||||
params = {"voice": args.voice, "format": "PcmMessagePack"}
|
||||
uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
|
||||
print(uri)
|
||||
|
||||
if args.inp == "-":
|
||||
if sys.stdin.isatty(): # Interactive
|
||||
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||
headers = {"kyutai-api-key": args.api_key}
|
||||
|
||||
async with websockets.connect(uri, additional_headers=headers) as websocket:
|
||||
print("connected")
|
||||
|
||||
async def send_loop():
|
||||
print("go send")
|
||||
async for line in get_lines(args.inp):
|
||||
for word in line.split():
|
||||
await websocket.send(msgpack.packb({"type": "Text", "text": word}))
|
||||
await websocket.send(msgpack.packb({"type": "Eos"}))
|
||||
|
||||
output_queue = asyncio.Queue()
|
||||
receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
|
||||
output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
|
||||
send_task = asyncio.create_task(send_loop())
|
||||
await asyncio.gather(receive_task, output_audio_task, send_task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(websocket_client())
|
||||
|
|
@ -228,11 +228,9 @@
|
|||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
140
tts_pytorch.ipynb
Normal file
140
tts_pytorch.ipynb
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Fast install, might break in the future.\n",
|
||||
"!pip install 'sphn<0.2'\n",
|
||||
"!pip install --no-deps \"moshi==0.2.11\"\n",
|
||||
"# Slow install (will download torch and cuda), but future proof.\n",
|
||||
"# !pip install \"moshi==0.2.11\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import argparse\n",
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"from moshi.models.loaders import CheckpointInfo\n",
|
||||
"from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel\n",
|
||||
"\n",
|
||||
"from IPython.display import display, Audio"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Configuration\n",
|
||||
"text = \"Hey there! How are you? I had the craziest day today.\"\n",
|
||||
"voice = \"expresso/ex03-ex01_happy_001_channel1_334s.wav\"\n",
|
||||
"print(f\"See https://huggingface.co/{DEFAULT_DSM_TTS_VOICE_REPO} for available voices.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set everything up\n",
|
||||
"checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n",
|
||||
"tts_model = TTSModel.from_checkpoint_info(\n",
|
||||
" checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n",
|
||||
"entries = tts_model.prepare_script([text], padding_between=1)\n",
|
||||
"voice_path = tts_model.get_voice_path(voice)\n",
|
||||
"# CFG coef goes here because the model was trained with CFG distillation,\n",
|
||||
"# so it's not _actually_ doing CFG at inference time.\n",
|
||||
"# Also, if you are generating a dialog, you should have two voices in the list.\n",
|
||||
"condition_attributes = tts_model.make_condition_attributes(\n",
|
||||
" [voice_path], cfg_coef=2.0\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"Generating audio...\")\n",
|
||||
"\n",
|
||||
"pcms = []\n",
|
||||
"def _on_frame(frame):\n",
|
||||
" print(\"Step\", len(pcms), end=\"\\r\")\n",
|
||||
" if (frame != -1).all():\n",
|
||||
" pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n",
|
||||
" pcms.append(np.clip(pcm[0, 0], -1, 1))\n",
|
||||
"\n",
|
||||
"# You could also generate multiple audios at once by extending the following lists.\n",
|
||||
"all_entries = [entries]\n",
|
||||
"all_condition_attributes = [condition_attributes]\n",
|
||||
"with tts_model.mimi.streaming(len(all_entries)):\n",
|
||||
" result = tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)\n",
|
||||
"\n",
|
||||
"print(\"Done generating.\")\n",
|
||||
"audio = np.concatenate(pcms, axis=-1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"display(\n",
|
||||
" Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.13.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user