Compare commits

..

5 Commits

Author SHA1 Message Date
gabrieldemarmiesse
403db09953 Remove empty space 2025-06-18 10:41:56 +00:00
gabrieldemarmiesse
332b2b9daa Clarify real-time 2025-06-18 10:40:26 +00:00
gabrieldemarmiesse
7c9953187a Merge branch 'main' into give_uv_instructions 2025-06-18 10:36:57 +00:00
gabrieldemarmiesse
6247aee904 Add french sample 2025-06-18 10:33:42 +00:00
gabrieldemarmiesse
e202e4bb0a Add uv instructions and ignore the sample audio file 2025-06-18 10:32:31 +00:00
39 changed files with 333 additions and 4512 deletions

View File

@@ -1,83 +0,0 @@
name: Bug Report
description: You found a bug.
labels: ["bug", "triage"]
body:
- type: markdown
attributes:
value: |
Please first check the [FAQ](https://github.com/kyutai-labs/delayed-streams-modeling/blob/main/FAQ.md).
- type: dropdown
id: backend
attributes:
label: Backend impacted
description: Which backend is concerned with your bug report?
options:
- The PyTorch implementation
- The MLX implementation
- The Rust implementation
- Other / All
default: 0
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating system
description: What is your operating system?
options:
- Linux
- Mac OS X
- Windows (unsupported)
default: 0
validations:
required: true
- type: dropdown
id: hardware
attributes:
label: Hardware
description: What hardware are you using?
options:
- CPU
- GPU with CUDA
- Metal with MLX
default: 0
validations:
required: true
- type: textarea
id: description
attributes:
label: Description
description: Provide a detailed description of your bug.
placeholder:
value:
validations:
required: true
- type: textarea
id: more_info
attributes:
label: Extra information
description: Please provide any other relevant information, such as log extracts, code etc.
placeholder:
value:
validations:
required: true
- type: textarea
id: env
attributes:
label: Environment
description: Please provide any other relevant information, such as log extracts, code etc.
placeholder:
value: |
Fill in the following information on your system.
- Operating system version:
If the backend impacted is PyTorch:
- Python version:
- PyTorch version:
- CUDA version (run `python -c 'import torch; print(torch.version.cuda)'`):
- GPU model and memory:
If the backend is MLX:
- Mac model:
validations:
required: true

View File

@@ -1,40 +0,0 @@
name: Question
description: You have a question about the codebase, the paper, or the implementation.
labels: ["question", "triage"]
body:
- type: markdown
attributes:
value: |
Please first check the [FAQ](https://github.com/kyutai-labs/delayed-streams-modeling/blob/main/FAQ.md).
- type: checkboxes
id: terms
attributes:
label: Due diligence
description: Have you searched the existing issues / FAQ / Google / asked ChatGPT?
options:
- label: I have done my due diligence in trying to find the answer myself.
required: true
- type: dropdown
id: backend
attributes:
label: Topic
description: What is your question about?
options:
- The paper
- The PyTorch implementation
- The MLX implementation
- The Rust implementation
- Other / All
default: 0
validations:
required: true
- type: textarea
id: question
attributes:
label: Question
description: What is your question?
placeholder: Your question. Please make sure this is directly related to our codebase. We will not provide support for installing PyTorch, CUDA, Rust etc.
value:
validations:
required: true

View File

@@ -1,9 +0,0 @@
## Checklist
- [ ] Read CONTRIBUTING.md and accept the CLA by including the provided snippet. We will not accept PRs without this.
- [ ] Run pre-commit hook.
- [ ] If you changed Rust code, run `cargo check`, `cargo clippy`, `cargo test`.
## PR Description
<!-- Description for the PR -->

View File

@@ -1,28 +0,0 @@
name: moshi_build
description: 'Build env.'
runs:
using: "composite"
steps:
- uses: actions/setup-python@v2
with:
python-version: '3.10.14'
- uses: actions/cache@v3
id: cache
with:
path: env
key: env-${{ hashFiles('moshi/pyproject.toml') }}
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
shell: bash
run: |
python3 -m venv env
. env/bin/activate
python -m pip install --upgrade pip
pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu
pip install moshi==0.2.7
pip install pre-commit
- name: Setup env
shell: bash
run: |
source env/bin/activate
pre-commit install

View File

@@ -1,17 +0,0 @@
name: precommit
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
run_precommit:
name: Run precommit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: ./.github/actions/moshi_build
- run: |
source env/bin/activate
pre-commit run --all-files

3
.gitignore vendored
View File

@@ -192,4 +192,5 @@ cython_debug/
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
out*.wav
bria.mp3
sample_fr_hibiki_crepes.mp3

View File

@@ -1,22 +0,0 @@
repos:
# Get rid of Jupyter Notebook output because we don't want to keep it in Git
- repo: https://github.com/kynan/nbstripout
rev: 0.8.1
hooks:
- id: nbstripout
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-added-large-files
args: ["--maxkb=2048"]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.7
hooks:
# Run the linter.
- id: ruff
types_or: [python, pyi] # Don't run on `jupyter` files
args: [--fix]
# Run the formatter.
- id: ruff-format
types_or: [python, pyi] # Don't run on `jupyter` files

View File

@@ -1,3 +0,0 @@
{
"python.analysis.typeCheckingMode": "standard"
}

View File

@@ -1,58 +0,0 @@
# Contributing to Delayed-Streams-Modeling
## Pull Requests
Delayed-Streams-Modeling is the implementation of a research paper.
Therefore, we do not plan on accepting many pull requests for new features.
However, we certainly welcome them for bug fixes.
1. Fork the repo and create your branch from `main`.
2. If you have changed APIs, update the documentation accordingly.
3. Ensure pre-commit hooks pass properly, in particular the linting and typing.
4. When changing the Rust code, run `cargo check`, `cargo clippy`, `cargo test`.
5. Accept the Contributor License Agreement (see after).
Note that in general, we will not accept refactoring of the code.
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a Contributor License Agreement.
If you agree with the full CLA provided in the next paragraph, copy the following statement into your PR, changing your GitHub handle:
> I, {your GitHub handle}, confirm that I have read and understood the terms of the CLA of Kyutai-labs, as outlined in the repository's CONTRIBUTING.md, and I agree to be bound by these terms.
The full CLA is provided as follows:
> I, {your GitHub handle}, hereby grant to Kyutai-labs a perpetual, worldwide, non-exclusive, royalty-free,
> irrevocable license to use, modify, distribute, and sublicense my Contributions.
> I understand and accept that Contributions are limited to modifications, improvements, or changes
> to the project's source code submitted via pull requests. I accept that Kyutai-labs has full discretion to
> review, accept, reject, or request changes to any Contributions I submit, and that submitting
> a pull request does not guarantee its inclusion in the project.
> By submitting a Contribution, I grant Kyutai-labs a perpetual, worldwide license to use, modify,
> reproduce, distribute, and create derivative works based on my Contributions.
> I also agree to assign all patent rights for any inventions or improvements that arise from my Contributions,
> giving Kyutai-labs full rights to file for and enforce patents.
> I understand that Kyutai-labs may commercialize, relicense, or exploit the project and my Contributions without further notice or obligation to me.
> I confirm that my Contributions are original and that I have the legal right to grant this license.
> If my Contributions include third-party materials, I will ensure that I have the necessary permissions
> and will disclose this information. I accept that once my Contributions are integrated, they may be altered or removed at Kyutai-labs's discretion.
> I acknowledge that I am making these Contributions voluntarily and will not receive any compensation.
> Furthermore, I understand that all Contributions, including mine, are provided on an "as-is" basis, with no warranties.
> By submitting a pull request, I agree to be bound by these terms.
## Issues
Please submit issues on our Github repository.
## License
By contributing to Delayed-Streams-Modeling, you agree that your contributions
will be licensed under the LICENSE-* files in the root directory of this source
tree. In particular, the rust code is licensed under APACHE, and the python code
under MIT.

56
FAQ.md
View File

@@ -1,56 +0,0 @@
# FAQ
Here are the answers to a number of frequently asked questions.
### Torch compilation issues
With some PyTorch/triton versions, one might encounter compilation errors
like the following:
```
Traceback (most recent call last):
...
File "site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1153, in make_launcher
"launch_enter_hook": binary.__class__.launch_enter_hook,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._inductor.exc.InductorError: AttributeError: type object 'CompiledKernel' has no attribute 'launch_enter_hook'
```
If that's the case, you can disable torch compilation by setting the following
environment variable.
```bash
export NO_TORCH_COMPILE=1
```
### Issues installing the sentencepiece dependency
On some Linux distributions (e.g. Arch) or on macOS, the local version of CMake can
be too recent for the sentencepiece dependency.
```
CMake Error at CMakeLists.txt:15 (cmake_minimum_required):
Compatibility with CMake < 3.5 has been removed from CMake.
```
You can either downgrade your CMake version (e.g. 3.31.0 works on Arch) or try
setting `CMAKE_POLICY_VERSION_MINIMUM=3.5`.
If you run into some errors when compiling the sentencepiece rust bindings,
these could also be due to gcc being too recent, e.g. gcc 15. You can get
around this by using gcc-13, e.g. by setting the following after installing
the proper gcc packages.
```bash
export CMAKE_C_COMPILER=/usr/bin/gcc-13
export CMAKE_CXX_COMPILER=/usr/bin/g++-13
CC=gcc-13 CXX=g++-13 cargo build --release
```
Alternatively you can set `CXXFLAGS="-include cstdint"`, see this
[issue](https://github.com/google/sentencepiece/issues/1108).
### Will you release training code?
Some finetuning code can be found in the [kyutai-labs/moshi-finetune repo](https://github.com/kyutai-labs/moshi-finetune).
This code has not been adapted to the Speech-To-Text and Text-To-Speech models
yet, but it should be a good starting point.

320
README.md
View File

@@ -1,123 +1,84 @@
# Delayed Streams Modeling: Kyutai STT & TTS
# delayed-streams-modeling
Delayed Streams Modeling (DSM) is a flexible formulation for streaming, multimodal sequence-to-sequence learning.
This repo contains instructions and examples of how to run
[Kyutai Speech-To-Text](#kyutai-speech-to-text)
and [Kyutai Text-To-Speech](#kyutai-text-to-speech) models.
These models are powered by delayed streams modeling (DSM),
a flexible formulation for streaming, multimodal sequence-to-sequence learning.
See also [Unmute](https://github.com/kyutai-labs/unmute), a voice AI system built using Kyutai STT and Kyutai TTS.
## Speech-to-text
But wait, what is "Delayed Streams Modeling"? It is a technique for solving many streaming X-to-Y tasks (with X, Y in `{speech, text}`)
that formalizes the approach we took with Moshi and Hibiki. A pre-print paper is coming soon!
DSM can be used to build streaming speech-to-text models. These models can be
batched for efficiency, return word level timestamps, and are great for
interactive applications. We provide two such models, these models are
characterized by their size as well as the delay it takes for audio to be
transcribed into text. We provide two such models:
- An English and French model with ~1b parameters using a 0.5 second delay,
`kyutai/stt-1b-en_fr`.
- An English only model with ~2.6b parameters using a 2.5 second delay,
`kyutai/stt-2.6b-en`.
## Kyutai Speech-To-Text
More details can be found on the [project page](https://kyutai.org/next/stt).
<a href="https://huggingface.co/collections/kyutai/speech-to-text-685403682cf8a23ab9466886" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiSTT-blue" style="display: inline-block; vertical-align: middle;"/>
</a>
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
You can retrieve the sample files used in the following snippets via:
```bash
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3
```
**More details can be found on the [project page](https://kyutai.org/next/stt).**
Kyutai STT models are optimized for real-time usage, can be batched for efficiency, and return word level timestamps.
We provide two models:
- `kyutai/stt-1b-en_fr`, an English and French model with ~1B parameters, a 0.5 second delay, and a [semantic VAD](https://kyutai.org/next/stt#semantic-vad).
- `kyutai/stt-2.6b-en`, an English-only model with ~2.6B parameters and a 2.5 second delay.
These speech-to-text models have several advantages:
- Streaming inference: the models can process audio in chunks, which allows
for real-time transcription and is great for interactive applications (see the
sketch after this list).
- Easy batching for maximum efficiency: an H100 can process 400 streams in
real-time.
- They return word-level timestamps.
- The 1B model has a semantic Voice Activity Detection (VAD) component that
can be used to detect when the user is speaking. This is especially useful
for building voice agents.
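To make the streaming-inference point above concrete, here is a minimal sketch of chunk-by-chunk transcription with the PyTorch implementation. It mirrors the batched evaluation script included later in this diff, reduced to a single stream; treat the exact module paths, the `cuda` device, and the padding-token default as assumptions to check against your installed `moshi` version (0.2.6 or later).
```python
# Minimal sketch: chunked (streaming) transcription with the moshi PyTorch API.
# API names follow the evaluation script in this repo; adjust for your setup.
import moshi.models
import sphn
import torch

info = moshi.models.loaders.CheckpointInfo.from_hf_repo("kyutai/stt-2.6b-en")
mimi = info.get_mimi(device="cuda")  # audio tokenizer
text_tokenizer = info.get_text_tokenizer()
lm_gen = moshi.models.LMGen(
    info.get_moshi(device="cuda", dtype=torch.bfloat16), temp=0, temp_text=0
)

audio, _ = sphn.read("audio/bria.mp3", sample_rate=mimi.sample_rate)
audio = torch.from_numpy(audio).to("cuda")  # shape: (channels, samples)

chunks = []
with mimi.streaming(1), lm_gen.streaming(1):
    # Feed the audio one frame at a time, as a live client would.
    for offset in range(0, audio.shape[-1] - mimi.frame_size + 1, mimi.frame_size):
        frame = audio[None, 0:1, offset : offset + mimi.frame_size]
        tokens = lm_gen.step(mimi.encode(frame))
        if tokens is not None:
            chunks.append(tokens)

tokens = torch.cat(chunks, dim=-1).view(-1)
padding_id = info.raw_config.get("text_padding_token_id", 3)
print(text_tokenizer.decode(tokens[tokens > padding_id].cpu().tolist()))
```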
### Implementations overview
We provide different implementations of Kyutai STT for different use cases.
Here is how to choose which one to use:
- **PyTorch: for research and tinkering.**
If you want to call the model from Python for research or experimentation, use our PyTorch implementation.
- **Rust: for production.**
If you want to serve Kyutai STT in a production setting, use our Rust server.
Our robust Rust server provides streaming access to the model over websockets.
We use this server to run [Unmute](https://unmute.sh/); on an L40S GPU, we can serve 64 simultaneous connections at a real-time factor of 3x.
- **MLX: for on-device inference on iPhone and Mac.**
MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
<details>
<summary>PyTorch implementation</summary>
### PyTorch implementation
<a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
</a>
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb">
<a target="_blank" href="https://colab.research.google.com/drive/1mc0Q-FoHxU2pEvId8rTdS4q1r1zorJhS?usp=sharing">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
For an example of how to use the model in a way where you can directly stream in PyTorch tensors,
[see our Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb).
This requires the [moshi package](https://pypi.org/project/moshi/)
with version 0.2.6 or later, which can be installed via pip.
If you just want to run the model on a file, you can use `moshi.run_inference`.
with version 0.2.5 or later, which can be installed via pip.
```bash
python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3
python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
```
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
and just prefix the command above with `uvx --with moshi`.
If you have `uv` installed, you can skip the installation step and run directly:
```bash
uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
```
It will install the moshi package in a temporary environment and run the speech-to-text.
Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs:
### MLX implementation
<a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
</a>
This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/)
with version 0.2.5 or later, which can be installed via pip.
```bash
uv run \
scripts/stt_from_file_pytorch.py \
--hf-repo kyutai/stt-2.6b-en \
audio/bria.mp3
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
```
The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
If you have `uv` installed, you can skip the installation step and run directly:
```bash
uv run scripts/evaluate_on_dataset.py \
--dataset meanwhile \
--hf-repo kyutai/stt-2.6b-en
uvx --with moshi-mlx python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
```
It will install the moshi package in a temporary environment and run the speech-to-text.
Another example shows how one can provide a text-, audio-, or text-audio prompt to our STT model:
### Rust implementation
<a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
</a>
A standalone Rust example is provided in the `stt-rs` directory in this repo.
This can be used as follows:
```bash
uv run scripts/stt_from_file_pytorch_with_prompt.py \
--hf-repo kyutai/stt-2.6b-en \
--file bria.mp3 \
--prompt_file ./audio/loonah.mp3 \
--prompt_text "Loonah" \
--cut-prompt-transcript
```
Produces the transcript of `bria.mp3` using the `Loonah` spelling for the name, instead of the `Luna` used without any prompt:
```
In the heart of an ancient forest, where the trees whispered secrets of the past, there lived a peculiar rabbit named Loonah (...)
cd stt-rs
cargo run --features cuda -r -- bria.mp3
```
Apart from nudging the model for a specific spelling of a word, other potential use-cases include speaker adaptation and steering the model towards a specific formatting style or even a language.
However, please bear in mind that this is an experimental feature and its behavior is very sensitive to the prompt provided.
</details>
<details>
<summary>Rust server</summary>
### Rust server
<a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
</a>
The Rust implementation provides a server that can process multiple streaming
queries in parallel. Depending on the amount of memory on your GPU, you may
queries in parallel. Depending on the amount of memory on your GPU, you may
have to adjust the batch size from the config file. For an L40S GPU, a batch size
of 64 works well and requests can be processed at 3x real-time speed.
@@ -132,182 +93,24 @@ cargo install --features cuda moshi-server
Then the server can be started via the following command using the config file
from this repository.
For `kyutai/stt-1b-en_fr`, use `configs/config-stt-en_fr-hf.toml`,
and for `kyutai/stt-2.6b-en`, use `configs/config-stt-en-hf.toml`.
```bash
moshi-server worker --config configs/config-stt-en_fr-hf.toml
moshi-server worker --config configs/config-stt-hf.toml
```
Once the server has started you can transcribe audio from your microphone with the following script.
Once the server has started you can run a streaming inference with the following
script.
```bash
uv run scripts/stt_from_mic_rust_server.py
```
We also provide a script for transcribing from an audio file.
```bash
uv run scripts/stt_from_file_rust_server.py audio/bria.mp3
uv run scripts/asr-streaming-query.py bria.mp3
```
The script limits the decoding speed to simulate real-time processing of the audio.
Faster processing can be triggered by setting
the real-time factor, e.g. `--rtf 1000` will process
the real-time factor, e.g. `--rtf 500` will process
the data as fast as possible.
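For reference, these scripts speak a simple msgpack-over-websocket protocol: the client streams 24 kHz PCM chunks as `Audio` messages and receives `Word`/`EndWord` events (and an echoed `Marker`) back. Below is a condensed sketch of that protocol, modeled on the file-transcription client script added in this diff; the local URL and the `open_token` API key come from the bundled configs and may differ in your deployment.
```python
# Minimal sketch of the websocket protocol spoken by the Rust server, condensed
# from the client script added in this diff. Assumes a local server on port 8080.
import asyncio

import msgpack
import sphn
import websockets

URL = "ws://127.0.0.1:8080/api/asr-streaming"
HEADERS = {"kyutai-api-key": "open_token"}


def pack(msg):
    return msgpack.packb(msg, use_bin_type=True, use_single_float=True)


async def transcribe(path: str) -> None:
    pcm, _ = sphn.read(path, sample_rate=24000)  # resampled to 24 kHz, channel-first
    async with websockets.connect(URL, additional_headers=HEADERS) as ws:

        async def send() -> None:
            # Lead-in silence, then the file, sent as fast as possible
            # (the bundled script instead paces this according to --rtf).
            await ws.send(pack({"type": "Audio", "pcm": [0.0] * 24000}))
            for i in range(0, pcm.shape[-1], 1920):
                chunk = [float(x) for x in pcm[0, i : i + 1920]]
                await ws.send(pack({"type": "Audio", "pcm": chunk}))
            # A marker is echoed back once the audio before it has been decoded;
            # trailing silence pushes it through the model's delay.
            await ws.send(pack({"type": "Audio", "pcm": [0.0] * 1920 * 5}))
            await ws.send(pack({"type": "Marker", "id": 0}))
            for _ in range(35):
                await ws.send(pack({"type": "Audio", "pcm": [0.0] * 1920}))

        async def receive() -> None:
            async for message in ws:
                data = msgpack.unpackb(message, raw=False)
                if data["type"] == "Word":
                    print(data["text"], end=" ", flush=True)
                elif data["type"] == "Marker":
                    break

        await asyncio.gather(send(), receive())


asyncio.run(transcribe("audio/bria.mp3"))
```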
</details>
<details>
<summary>Rust standalone</summary>
<a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
</a>
## Text-to-Speech
A standalone Rust example script is provided in the `stt-rs` directory in this repo.
This can be used as follows:
```bash
cd stt-rs
cargo run --features cuda -r -- ../audio/bria.mp3
```
You can get the timestamps by adding the `--timestamps` flag, and see the output
of the semantic VAD by adding the `--vad` flag.
</details>
<details>
<summary>MLX implementation</summary>
<a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
</a>
[MLX](https://ml-explore.github.io/mlx/build/html/index.html) is Apple's ML framework that allows you to use
hardware acceleration on Apple silicon.
This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/)
with version 0.2.6 or later, which can be installed via pip.
If you just want to run the model on a file, you can use `moshi_mlx.run_inference`:
```bash
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0
```
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
and just prefix the command above with `uvx --with moshi-mlx`.
If you want to transcribe audio from your microphone, use:
```bash
python scripts/stt_from_mic_mlx.py
```
The MLX models can also be used in Swift using the [moshi-swift
codebase](https://github.com/kyutai-labs/moshi-swift); the 1b model has been
tested to work fine on an iPhone 16 Pro.
</details>
## Kyutai Text-to-Speech
<a href="https://huggingface.co/collections/kyutai/text-to-speech-6866192e7e004ed04fd39e29" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiTTS-blue" style="display: inline-block; vertical-align: middle;"/>
</a>
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
**More details can be found on the [project page](https://kyutai.org/next/tts).**
We provide different implementations of Kyutai TTS for different use cases. Here is how to choose which one to use:
- PyTorch: for research and tinkering. If you want to call the model from Python for research or experimentation, use our PyTorch implementation.
- Rust: for production. If you want to serve Kyutai TTS in a production setting, use our Rust server. Our robust Rust server provides streaming access to the model over websockets. We use this server to run Unmute.
- MLX: for on-device inference on iPhone and Mac. MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon. If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
<details>
<summary>PyTorch implementation</summary>
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
Check out our [Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb) or use the script:
```bash
# From stdin, plays audio immediately
echo "Hey, how are you?" | python scripts/tts_pytorch.py - -
# From text file to audio file
python scripts/tts_pytorch.py text_to_say.txt audio_output.wav
```
The `tts_pytorch.py` script waits for all the text to be available before
starting the audio generation. A fully streaming implementation is available in
the `tts_pytorch_streaming.py` script, which can be used as follows:
```bash
echo "Hey, how are you?" | python scripts/tts_pytorch_streaming.py audio_output.wav
```
This requires the [moshi package](https://pypi.org/project/moshi/), which can be installed via pip.
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
and just prefix the command above with `uvx --with moshi`.
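If you prefer to call the model from Python rather than through the scripts, the sketch below shows the underlying `moshi` TTS API. The calls and parameter values (`n_q=32`, `temp=0.6`, `cfg_coef=2.0`, the expresso voice name) are copied from the API server removed elsewhere in this diff; treat them as a starting point to verify against your installed `moshi` version rather than a canonical recipe.
```python
# Minimal sketch: generate speech with the moshi TTS API (PyTorch).
# Parameter values mirror the API server removed elsewhere in this diff.
import numpy as np
import soundfile as sf
import torch
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel

device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
tts_model = TTSModel.from_checkpoint_info(checkpoint_info, n_q=32, temp=0.6, device=device)

entries = tts_model.prepare_script(["Hey, how are you?"], padding_between=1)
voice_path = tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav")
condition_attributes = tts_model.make_condition_attributes([voice_path], cfg_coef=2.0)

pcms = []

def on_frame(frame):
    # Frames still being generated are marked with -1; decode the finished ones.
    if (frame != -1).all():
        pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
        pcms.append(np.clip(pcm[0, 0], -1, 1))

with tts_model.mimi.streaming(1):
    tts_model.generate([entries], [condition_attributes], on_frame=on_frame)

audio = np.concatenate(pcms, axis=-1)
sf.write("audio_output.wav", audio, samplerate=int(tts_model.mimi.sample_rate))
```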
</details>
<details>
<summary>Rust server</summary>
The Rust implementation provides a server that can process multiple streaming
queries in parallel.
Installing the Rust server is a bit tricky because it uses our Python implementation under the hood,
which also requires installing the Python dependencies.
Use the [start_tts.sh](https://github.com/kyutai-labs/unmute/blob/main/dockerless/start_tts.sh) script to properly install the Rust server.
If you already installed the `moshi-server` crate before and it's not working, you might need to force a reinstall by running `cargo uninstall moshi-server` first.
Feel free to open an issue if the installation is still broken.
Once installed, the server can be started via the following command using the config file
from this repository.
```bash
moshi-server worker --config configs/config-tts.toml
```
Once the server has started you can connect to it using our script as follows:
```bash
# From stdin, plays audio immediately
echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
# From text file to audio file
python scripts/tts_rust_server.py text_to_say.txt audio_output.wav
```
</details>
<details>
<summary>MLX implementation</summary>
[MLX](https://ml-explore.github.io/mlx/build/html/index.html) is Apple's ML framework that allows you to use
hardware acceleration on Apple silicon.
Use our example script to run Kyutai TTS on MLX.
The script takes text from stdin or a file and can output to a file or stream the resulting audio.
When streaming the output, if the model is not fast enough to keep up with
real-time, you can use the `--quantize 8` or `--quantize 4` flags to quantize
the model, resulting in faster inference.
```bash
# From stdin, plays audio immediately
echo "Hey, how are you?" | python scripts/tts_mlx.py - - --quantize 8
# From text file to audio file
python scripts/tts_mlx.py text_to_say.txt audio_output.wav
```
This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/), which can be installed via pip.
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
and just prefix the command above with `uvx --with moshi-mlx`.
</details>
## FAQ
Check out the [Frequently Asked Questions](FAQ.md) section before opening an issue.
We're in the process of open-sourcing our TTS models. Check back for updates!
## License
@@ -317,14 +120,3 @@ Note that parts of this code is based on [AudioCraft](https://github.com/faceboo
the MIT license.
The weights for the speech-to-text models are released under the CC-BY 4.0 license.
## Developing
Install the [pre-commit hooks](https://pre-commit.com/) by running:
```bash
pip install pre-commit
pre-commit install
```
If you're using `uv`, you can replace the two commands with `uvx pre-commit install`.

View File

@@ -1,298 +0,0 @@
#!/usr/bin/env python3
"""
OpenAI-Compatible Kyutai TTS API Server with Model Caching
Improved version that loads the model once and keeps it in memory
"""
import os
import io
import time
import asyncio
import subprocess
from pathlib import Path
from typing import Optional, Literal
import logging
import torch
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global model variables - loaded once at startup
tts_model = None
device = None
sample_rate = None
class SpeechRequest(BaseModel):
model: Literal["tts-1", "tts-1-hd"] = Field("tts-1", description="TTS model to use")
input: str = Field(..., min_length=1, max_length=4096, description="Text to generate audio for")
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field("alloy", description="Voice to use")
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field("mp3", description="Audio format")
speed: Optional[float] = Field(1.0, ge=0.25, le=4.0, description="Speed of generated audio")
app = FastAPI(
title="OpenAI-Compatible TTS API (Cached)",
description="OpenAI Audio Speech API compatible endpoint using Kyutai TTS with model caching",
version="2.0.0"
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
OUTPUT_DIR = Path("/app/api_output")
OUTPUT_DIR.mkdir(exist_ok=True)
def load_tts_model():
"""Load TTS model once at startup and keep in memory"""
global tts_model, device, sample_rate
if tts_model is not None:
logger.info("TTS model already loaded")
return
try:
logger.info("🚀 Loading Kyutai TTS model (one-time initialization)...")
# Import Kyutai TTS modules
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Load the TTS model
checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
tts_model = TTSModel.from_checkpoint_info(
checkpoint_info,
n_q=32,
temp=0.6,
device=device
)
# Get sample rate
sample_rate = tts_model.mimi.sample_rate
logger.info(f"✅ TTS model loaded successfully!")
logger.info(f" Model: {DEFAULT_DSM_TTS_REPO}")
logger.info(f" Device: {device}")
logger.info(f" Sample Rate: {sample_rate}")
except Exception as e:
logger.error(f"❌ Failed to load TTS model: {e}")
raise
def generate_audio_fast(text: str, voice: str = "alloy", speed: float = 1.0) -> bytes:
"""Generate audio using cached TTS model"""
global tts_model, device, sample_rate
if tts_model is None:
raise HTTPException(status_code=500, detail="TTS model not loaded")
try:
logger.info(f"🎵 Generating audio for: '{text[:50]}{'...' if len(text) > 50 else ''}'")
# Prepare the script (text input)
entries = tts_model.prepare_script([text], padding_between=1)
# Voice mapping for OpenAI compatibility
voice_mapping = {
"alloy": "expresso/ex03-ex01_happy_001_channel1_334s.wav",
"echo": "expresso/ex04-ex01_happy_001_channel1_334s.wav",
"fable": "expresso/ex05-ex01_happy_001_channel1_334s.wav",
"onyx": "expresso/ex06-ex01_happy_001_channel1_334s.wav",
"nova": "expresso/ex07-ex01_happy_001_channel1_334s.wav",
"shimmer": "expresso/ex08-ex01_happy_001_channel1_334s.wav"
}
selected_voice = voice_mapping.get(voice, voice_mapping["alloy"])
try:
voice_path = tts_model.get_voice_path(selected_voice)
except Exception:
# Fallback to default if voice not found
voice_path = tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav")
# Prepare condition attributes
condition_attributes = tts_model.make_condition_attributes(
[voice_path], cfg_coef=2.0
)
# Generate audio
pcms = []
def on_frame(frame):
if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.append(torch.clamp(torch.from_numpy(pcm[0, 0]), -1, 1).numpy())
all_entries = [entries]
all_condition_attributes = [condition_attributes]
with tts_model.mimi.streaming(len(all_entries)):
result = tts_model.generate(all_entries, all_condition_attributes, on_frame=on_frame)
# Concatenate all audio frames
if pcms:
import numpy as np
audio = np.concatenate(pcms, axis=-1)
# Apply speed adjustment if needed
if speed != 1.0:
# Simple speed adjustment by resampling
from scipy import signal
audio_length = len(audio)
new_length = int(audio_length / speed)
audio = signal.resample(audio, new_length)
# Convert to bytes
audio_bytes = io.BytesIO()
sf.write(audio_bytes, audio, samplerate=sample_rate, format='WAV')
audio_bytes.seek(0)
logger.info(f"✅ Audio generated successfully ({len(audio)/sample_rate:.2f}s)")
return audio_bytes.read()
else:
raise Exception("No audio frames generated")
except Exception as e:
logger.error(f"❌ TTS generation error: {e}")
raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")
def convert_audio_format(audio_wav_bytes: bytes, output_format: str) -> bytes:
"""Convert WAV audio to requested format using ffmpeg"""
try:
if output_format == "wav":
return audio_wav_bytes
# Use ffmpeg to convert
cmd = ["ffmpeg", "-f", "wav", "-i", "pipe:0", "-f", output_format, "pipe:1"]
result = subprocess.run(
cmd,
input=audio_wav_bytes,
capture_output=True,
check=True
)
return result.stdout
except subprocess.CalledProcessError as e:
logger.error(f"Audio conversion failed: {e}")
raise HTTPException(status_code=500, detail=f"Audio conversion failed: {e}")
@app.post("/v1/audio/speech")
async def create_speech(request: SpeechRequest):
"""
OpenAI-compatible audio speech endpoint
Uses cached TTS model for fast generation
"""
try:
start_time = time.time()
# Generate audio with cached model
audio_wav_bytes = generate_audio_fast(
text=request.input,
voice=request.voice,
speed=request.speed
)
# Convert to requested format
audio_data = convert_audio_format(audio_wav_bytes, request.response_format)
generation_time = time.time() - start_time
logger.info(f"⚡ Total generation time: {generation_time:.2f}s")
# Set appropriate content type
content_types = {
"mp3": "audio/mpeg",
"opus": "audio/opus",
"aac": "audio/aac",
"flac": "audio/flac",
"wav": "audio/wav",
"pcm": "audio/pcm"
}
return Response(
content=audio_data,
media_type=content_types.get(request.response_format, "audio/wav"),
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Generation-Time": str(generation_time)
}
)
except Exception as e:
logger.error(f"Speech generation failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/v1/models")
async def list_models():
"""List available models (OpenAI-compatible)"""
return {
"object": "list",
"data": [
{
"id": "tts-1",
"object": "model",
"created": 1677610602,
"owned_by": "kyutai",
"permission": [],
"root": "tts-1",
"parent": None
},
{
"id": "tts-1-hd",
"object": "model",
"created": 1677610602,
"owned_by": "kyutai",
"permission": [],
"root": "tts-1-hd",
"parent": None
}
]
}
@app.get("/health")
async def health_check():
"""Health check endpoint with model status"""
model_loaded = tts_model is not None
return {
"status": "healthy" if model_loaded else "loading",
"model_loaded": model_loaded,
"cuda_available": torch.cuda.is_available(),
"device": str(device) if device else None,
"service": "kyutai-tts-openai-compatible-cached"
}
@app.get("/reload-model")
async def reload_model():
"""Reload the TTS model (admin endpoint)"""
global tts_model
try:
tts_model = None
load_tts_model()
return {"status": "success", "message": "Model reloaded successfully"}
except Exception as e:
return {"status": "error", "message": str(e)}
@app.on_event("startup")
async def startup_event():
"""Load model on startup"""
logger.info("🚀 Starting TTS API server with model caching...")
load_tts_model()
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
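For context, a client would exercise this server with a standard OpenAI-style speech request. The following sketch assumes the server above is running locally on port 8000, as configured in the `uvicorn.run` call, and simply saves the returned audio:
```python
# Sketch: call the OpenAI-compatible /v1/audio/speech endpoint defined above.
# Assumes the server is running locally on port 8000.
import requests

resp = requests.post(
    "http://localhost:8000/v1/audio/speech",
    json={
        "model": "tts-1",
        "input": "Hey, how are you?",
        "voice": "alloy",           # mapped to an expresso voice by the server
        "response_format": "wav",   # avoids the ffmpeg conversion path
    },
    timeout=120,
)
resp.raise_for_status()
with open("speech.wav", "wb") as fobj:
    fobj.write(resp.content)
print("generation time:", resp.headers.get("X-Generation-Time"))
```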

View File

@@ -1,67 +0,0 @@
#!/usr/bin/env python3
"""
Check if all Kyutai TTS dependencies are properly installed
"""
import sys
def check_dependencies():
print("🔍 Checking Kyutai TTS Dependencies")
print("=" * 40)
dependencies = [
"torch",
"numpy",
"einops",
"transformers",
"accelerate",
"soundfile",
"librosa",
"huggingface_hub",
"moshi",
"sphn"
]
missing = []
installed = []
for dep in dependencies:
try:
__import__(dep)
installed.append(dep)
print(f"{dep}")
except ImportError as e:
missing.append((dep, str(e)))
print(f"{dep}: {e}")
print(f"\n📊 Summary:")
print(f"✓ Installed: {len(installed)}")
print(f"✗ Missing: {len(missing)}")
if missing:
print(f"\n🔧 To fix missing dependencies:")
for dep, error in missing:
print(f"pip install {dep}")
print(f"\n🧪 Testing Kyutai TTS imports:")
try:
from moshi.models.loaders import CheckpointInfo
print("✓ CheckpointInfo import successful")
except Exception as e:
print(f"✗ CheckpointInfo import failed: {e}")
try:
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
print("✓ TTSModel imports successful")
except Exception as e:
print(f"✗ TTSModel imports failed: {e}")
return len(missing) == 0
if __name__ == "__main__":
success = check_dependencies()
if success:
print("\n🎉 All dependencies are installed correctly!")
else:
print("\n❌ Some dependencies are missing. Please install them first.")
sys.exit(1)

View File

@@ -1,59 +0,0 @@
#!/usr/bin/env python3
"""
Kyutai TTS PyTorch Runner
Dockerized implementation for text-to-speech generation
"""
import sys
import os
import argparse
import torch
from pathlib import Path
def main():
parser = argparse.ArgumentParser(description='Kyutai TTS PyTorch Runner')
parser.add_argument('input_file', help='Input text file or "-" for stdin')
parser.add_argument('output_file', help='Output audio file')
parser.add_argument('--model', default='kyutai/tts-1.6b-en_fr', help='TTS model to use')
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use')
args = parser.parse_args()
print(f"Using device: {args.device}")
print(f"CUDA available: {torch.cuda.is_available()}")
# Handle stdin input
if args.input_file == '-':
# Read from stdin and create temporary file
text = sys.stdin.read().strip()
temp_file = '/tmp/temp_input.txt'
with open(temp_file, 'w') as f:
f.write(text)
input_file = temp_file
else:
input_file = args.input_file
# Check if the original TTS script exists
tts_script = Path('/app/scripts/tts_pytorch.py')
if tts_script.exists():
print("Using original TTS script from Kyutai repository")
import subprocess
cmd = ['python', str(tts_script), input_file, args.output_file]
subprocess.run(cmd, check=True)
else:
print("Using moshi package for TTS generation")
import subprocess
cmd = [
'python', '-m', 'moshi.run_inference',
'--hf-repo', args.model,
input_file,
args.output_file
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"Error: {result.stderr}")
sys.exit(1)
print(f"Audio generated: {args.output_file}")
if __name__ == '__main__':
main()

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1,7 +1,7 @@
static_dir = "./static/"
log_dir = "$HOME/tmp/tts-logs"
instance_name = "tts"
authorized_ids = ["public_token"]
authorized_ids = ["open_token"]
[modules.asr]
path = "/api/asr-streaming"

View File

@@ -1,7 +1,7 @@
static_dir = "./static/"
log_dir = "$HOME/tmp/tts-logs"
instance_name = "tts"
authorized_ids = ["public_token"]
authorized_ids = ["open_token"]
[modules.asr]
path = "/api/asr-streaming"

View File

@@ -1,20 +0,0 @@
static_dir = "./static/"
log_dir = "$HOME/tmp/tts-logs"
instance_name = "tts"
authorized_ids = ["public_token"]
[modules.tts_py]
type = "Py"
path = "/api/tts_streaming"
text_tokenizer_file = "hf://kyutai/tts-1.6b-en_fr/tokenizer_spm_8k_en_fr_audio.model"
batch_size = 8 # Adjust to your GPU memory capacity
text_bos_token = 1
[modules.tts_py.py]
log_folder = "$HOME/tmp/moshi-server-logs"
voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
default_voice = "unmute-prod-website/default_voice.wav"
cfg_coef = 2.0
cfg_is_no_text = true
padding_between = 1
n_q = 24

View File

@@ -1,78 +0,0 @@
# Set environment variables
export DEBIAN_FRONTEND=noninteractive
export PYTHONUNBUFFERED=1
export CUDA_VISIBLE_DEVICES=0
# Install system dependencies
apt-get update && apt-get install -y \
wget \
curl \
git \
build-essential \
libsndfile1 \
ffmpeg \
sox \
alsa-utils \
pulseaudio \
&& rm -rf /var/lib/apt/lists/*
# Install Python dependencies first (for better caching)
pip install --no-cache-dir --upgrade pip
# Create virtual environment
apt install python3.12-venv python3.12-dev
python3.12 -m venv ~/venv-tts-kyutai
source ~/venv-tts-kyutai/bin/activate
# Install Python dependencies first (for better caching)
pip install --no-cache-dir --upgrade pip
# Install PyTorch with CUDA support for Python 3.12
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
# Install core dependencies
pip install --no-cache-dir \
numpy \
scipy \
librosa \
soundfile \
huggingface_hub \
einops \
transformers \
accelerate
# Install API dependencies
pip install --no-cache-dir \
fastapi \
uvicorn[standard] \
python-multipart \
pydantic
# Install moshi package with all dependencies (following Colab notebook)
pip install --no-cache-dir 'sphn<0.2'
pip install --no-cache-dir "moshi==0.2.8"
# Create directories for input/output
mkdir -p /app/input /app/output /app/scripts /app/api_output
# Download the Kyutai delayed-streams-modeling repository
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo
# Copy the TTS script from the repository
cp /app/kyutai-repo/scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"
# Create directories for input/output
mkdir -p /app/input /app/output /app/scripts /app/api_output
# Download the Kyutai delayed-streams-modeling repository
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo
# Copy the TTS script from the repository
cp scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"
# Create directories for input/output
mkdir -p /app/input /app/output /app/scripts /app/api_output
# Start TTS-Server
python /app/api_server.py

View File

@@ -0,0 +1,131 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "msgpack",
# "numpy",
# "sphn",
# "websockets",
# ]
# ///
import argparse
import asyncio
import json
import msgpack
import sphn
import struct
import time
import numpy as np
import websockets
# Desired audio properties
TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
HEADERS = {"kyutai-api-key": "open_token"}
all_text = []
transcript = []
finished = False
def load_and_process_audio(file_path):
"""Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
pcm_data, _ = sphn.read(file_path, sample_rate=TARGET_SAMPLE_RATE)
return pcm_data[0]
async def receive_messages(websocket):
global all_text
global transcript
global finished
try:
async for message in websocket:
data = msgpack.unpackb(message, raw=False)
if data["type"] == "Step":
continue
print("received:", data)
if data["type"] == "Word":
all_text.append(data["text"])
transcript.append({
"speaker": "SPEAKER_00",
"text": data["text"],
"timestamp": [data["start_time"], data["start_time"]],
})
if data["type"] == "EndWord":
if len(transcript) > 0:
transcript[-1]["timestamp"][1] = data["stop_time"]
if data["type"] == "Marker":
print("Received marker, stopping stream.")
break
except websockets.ConnectionClosed:
print("Connection closed while receiving messages.")
finished = True
async def send_messages(websocket, rtf: float):
global finished
audio_data = load_and_process_audio(args.in_file)
try:
# Start with a second of silence
chunk = { "type": "Audio", "pcm": [0.0] * 24000 }
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
chunk_size = 1920  # Send 80 ms of audio (1920 samples at 24 kHz) per message
start_time = time.time()
for i in range(0, len(audio_data), chunk_size):
chunk = { "type": "Audio", "pcm": [float(x) for x in audio_data[i : i + chunk_size]] }
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
expected_send_time = start_time + (i + 1) / 24000 / rtf
current_time = time.time()
if current_time < expected_send_time:
await asyncio.sleep(expected_send_time - current_time)
else:
await asyncio.sleep(0.001)
chunk = { "type": "Audio", "pcm": [0.0] * 1920 * 5 }
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
msg = msgpack.packb({"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
for _ in range(35):
chunk = { "type": "Audio", "pcm": [0.0] * 1920 }
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
while True:
if finished:
break
await asyncio.sleep(1.0)
# Keep the connection alive as there is a 20s timeout on the rust side.
await websocket.ping()
except websockets.ConnectionClosed:
print("Connection closed while sending messages.")
async def stream_audio(url: str, rtf: float):
"""Stream audio data to a WebSocket server."""
async with websockets.connect(url, additional_headers=HEADERS) as websocket:
send_task = asyncio.create_task(send_messages(websocket, rtf))
receive_task = asyncio.create_task(receive_messages(websocket))
await asyncio.gather(send_task, receive_task)
print("exiting")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("in_file")
parser.add_argument("--transcript")
parser.add_argument(
"--url",
help="The url of the server to which to send the audio",
default="ws://127.0.0.1:8080",
)
parser.add_argument("--rtf", type=float, default=1.01)
args = parser.parse_args()
url = f"{args.url}/api/asr-streaming"
asyncio.run(stream_audio(url, args.rtf))
print(" ".join(all_text))
if args.transcript is not None:
with open(args.transcript, "w") as fobj:
json.dump({"transcript": transcript}, fobj, indent=4)

View File

@@ -1,387 +0,0 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "datasets",
# "jiwer==3.1.0",
# "julius",
# "librosa",
# "moshi",
# "openai-whisper",
# "soundfile",
# ]
# ///
"""
Example implementation of the streaming STT example. Here we group
test utterances in batches (pre- and post-padded with silence) and
then feed these batches into the streaming STT model frame-by-frame.
"""
# The outputs I get on my H100 using this code with the 2.6B model,
# bsz 32:
# LibriVox === cer: 4.09% wer: 7.33% corpus_wer: 6.78% RTF = 52.72
# Ami === cer: 15.99% wer: 18.78% corpus_wer: 12.20% RTF = 28.37
# LibriSpeech other === cer: 2.31% wer: 5.24% corpus_wer: 4.33% RTF = 44.76
# LibriSpeech clean === cer: 0.67% wer: 1.95% corpus_wer: 1.69% RTF = 68.19
# Tedlium (short) === cer: 2.15% wer: 3.65% corpus_wer: 3.33% RTF = 67.44
# spgispeech === cer: 0.99% wer: 2.00% corpus_wer: 2.03% RTF = 78.64
# gigaspeech === cer: 6.80% wer: 11.31% corpus_wer: 9.81% RTF = 64.04
# earnings22 (short) === cer: 12.63% wer: 15.70% corpus_wer: 11.02% RTF = 50.13
# Meanwhile === cer: 2.02% wer: 5.50% corpus_wer: 5.60% RTF = 69.19
# Tedlium (long) == cer: 1.53% wer: 2.56% corpus_wer: 2.97% RTF = 33.92
# Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34
# Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15
import argparse
import dataclasses
import time
import jiwer
import julius
import moshi.models
import torch
import tqdm
from datasets import Dataset, load_dataset
from whisper.normalizers import EnglishTextNormalizer
_NORMALIZER = EnglishTextNormalizer()
def get_text(sample):
possible_keys = [
"text",
"sentence",
"normalized_text",
"transcript",
"transcription",
]
for key in possible_keys:
if key in sample:
return sample[key]
raise ValueError(
f"Expected transcript column of either {possible_keys}."
f"Got sample with keys: {', '.join(sample.keys())}. Ensure a text column name is present in the dataset."
)
# The two functions below are adapted from https://github.com/huggingface/open_asr_leaderboard/blob/main/normalizer/data_utils.py
def normalize(batch):
batch["original_text"] = get_text(batch)
batch["norm_text"] = _NORMALIZER(batch["original_text"])
return batch
def is_target_text_in_range(ref):
if ref.strip() == "ignore time segment in scoring":
return False
else:
return ref.strip() != ""
# End of the adapted part
class AsrMetrics:
def __init__(self):
self.cer_sum = 0.0
self.wer_sum = 0.0
self.errors_sum = 0.0
self.total_words_sum = 0.0
self.num_sequences = 0.0
def update(self, hyp: str, ref: str) -> None:
normalized_ref = _NORMALIZER(ref)
normalized_hyp = _NORMALIZER(hyp)
this_wer = jiwer.wer(normalized_ref, normalized_hyp)
this_cer = jiwer.cer(normalized_ref, normalized_hyp)
measures = jiwer.compute_measures(normalized_ref, normalized_hyp)
self.wer_sum += this_wer
self.cer_sum += this_cer
self.errors_sum += (
measures["substitutions"] + measures["deletions"] + measures["insertions"]
)
self.total_words_sum += (
measures["substitutions"] + measures["deletions"] + measures["hits"]
)
self.num_sequences += 1
def compute(self) -> dict:
assert self.num_sequences > 0, (
"Unable to compute with total number of comparisons <= 0"
) # type: ignore
return {
"cer": (self.cer_sum / self.num_sequences),
"wer": (self.wer_sum / self.num_sequences),
"corpus_wer": (self.errors_sum / self.total_words_sum),
}
def __str__(self) -> str:
result = self.compute()
return " ".join(f"{k}: {100 * v:.2f}%" for k, v in result.items())
class Timer:
def __init__(self):
self.total = 0
self._start_time = None
def __enter__(self):
self._start_time = time.perf_counter()
return self
def __exit__(self, *_):
self.total += time.perf_counter() - self._start_time
self._start_time = None
@dataclasses.dataclass
class _DatasetInfo:
alias: str
name: str
config: str
split: str = "test"
_DATASETS = [
# Long-form datasets from distil-whisper
_DatasetInfo("rev16", "distil-whisper/rev16", "whisper_subset"),
_DatasetInfo("earnings21", "distil-whisper/earnings21", "full"),
_DatasetInfo("earnings22", "distil-whisper/earnings22", "full"),
_DatasetInfo("tedlium", "distil-whisper/tedlium-long-form", None),
_DatasetInfo("meanwhile", "distil-whisper/meanwhile", None),
# Short-form datasets from OpenASR leaderboard
_DatasetInfo("ami", "hf-audio/esb-datasets-test-only-sorted", "ami"),
_DatasetInfo(
"librispeech.clean",
"hf-audio/esb-datasets-test-only-sorted",
"librispeech",
split="test.clean",
),
_DatasetInfo(
"librispeech.other",
"hf-audio/esb-datasets-test-only-sorted",
"librispeech",
split="test.other",
),
_DatasetInfo("voxpopuli", "hf-audio/esb-datasets-test-only-sorted", "voxpopuli"),
_DatasetInfo("spgispeech", "hf-audio/esb-datasets-test-only-sorted", "spgispeech"),
_DatasetInfo("gigaspeech", "hf-audio/esb-datasets-test-only-sorted", "gigaspeech"),
_DatasetInfo("tedlium-short", "hf-audio/esb-datasets-test-only-sorted", "tedlium"),
_DatasetInfo(
"earnings22-short", "hf-audio/esb-datasets-test-only-sorted", "earnings22"
),
]
DATASET_MAP = {dataset.alias: dataset for dataset in _DATASETS}
def get_dataset(args) -> Dataset:
if args.dataset not in DATASET_MAP:
raise RuntimeError(f"Unknown dataset: {args.dataset}")
info = DATASET_MAP[args.dataset]
dataset = load_dataset(
info.name,
info.config,
split=info.split,
cache_dir=args.hf_cache_dir,
streaming=False,
token=True,
)
dataset = dataset.map(normalize)
dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])
return dataset
@torch.no_grad
def get_padded_batch(
audios: list[tuple[torch.Tensor, int]],
before_padding: float,
after_padding: float,
audio_encoder,
):
sample_rate = audio_encoder.sample_rate
max_len = 0
batch = []
durations = []
for audio, sr in audios:
durations.append(audio.shape[-1] / sr)
audio = julius.resample_frac(audio, int(sr), int(sample_rate))
audio = torch.nn.functional.pad(
audio, (int(before_padding * sample_rate), int(after_padding * sample_rate))
)
max_len = max(max_len, audio.shape[-1])
batch.append(audio)
target = max_len
if target % audio_encoder.frame_size != 0:
target = target + (
audio_encoder.frame_size - max_len % audio_encoder.frame_size
)
padded_batch = torch.stack(
[
torch.nn.functional.pad(audio, (0, target - audio.shape[-1]))
for audio in batch
]
)
return padded_batch
@torch.no_grad
def streaming_transcribe(
padded_batch: torch.Tensor,
mimi,
lm_gen,
):
bsz = padded_batch.shape[0]
text_tokens_acc = []
with mimi.streaming(bsz), lm_gen.streaming(bsz):
for offset in range(0, padded_batch.shape[-1], mimi.frame_size):
audio_chunk = padded_batch[:, offset : offset + mimi.frame_size]
audio_chunk = audio_chunk[:, None, :]
audio_tokens = mimi.encode(audio_chunk)
text_tokens = lm_gen.step(audio_tokens)
if text_tokens is not None:
text_tokens_acc.append(text_tokens)
return torch.concat(text_tokens_acc, axis=-1)
def run_inference(
dataset,
mimi,
lm_gen,
tokenizer,
padding_token_id,
before_padding_sec,
after_padding_sec,
):
metrics = AsrMetrics()
audio_time = 0.0
inference_timer = Timer()
for batch in tqdm.tqdm(dataset.iter(args.batch_size)):
audio_data = list(
zip(
[torch.tensor(x["array"]).float() for x in batch["audio"]],
[x["sampling_rate"] for x in batch["audio"]],
)
)
audio_time += sum(audio.shape[-1] / sr for (audio, sr) in audio_data)
gt_transcripts = batch["original_text"]
padded_batch = get_padded_batch(
audio_data,
before_padding=before_padding_sec,
after_padding=after_padding_sec,
audio_encoder=mimi,
)
padded_batch = padded_batch.cuda()
with inference_timer:
text_tokens = streaming_transcribe(
padded_batch,
mimi=mimi,
lm_gen=lm_gen,
)
for batch_index in range(text_tokens.shape[0]):
utterance_tokens = text_tokens[batch_index, ...]
utterance_tokens = utterance_tokens[utterance_tokens > padding_token_id]
text = tokenizer.decode(utterance_tokens.cpu().numpy().tolist())
metrics.update(hyp=text, ref=gt_transcripts[batch_index])
return metrics, inference_timer.total, audio_time
def main(args):
torch.set_float32_matmul_precision("high")
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
args.hf_repo,
moshi_weights=args.moshi_weight,
mimi_weights=args.mimi_weight,
tokenizer=args.tokenizer,
config_path=args.config_path,
)
mimi = info.get_mimi(device=args.device)
tokenizer = info.get_text_tokenizer()
lm = info.get_moshi(
device=args.device,
dtype=torch.bfloat16,
)
lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
dataset = get_dataset(args)
padding_token_id = info.raw_config.get("text_padding_token_id", 3)
# Putting in some conservative defaults
audio_silence_prefix_seconds = info.stt_config.get(
"audio_silence_prefix_seconds", 1.0
)
audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
wer_metric, inference_time, audio_time = run_inference(
dataset,
mimi,
lm_gen,
tokenizer,
padding_token_id,
audio_silence_prefix_seconds,
audio_delay_seconds + 0.5,
)
print(wer_metric, f"RTF = {audio_time / inference_time:.2f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Example streaming STT inference.")
parser.add_argument(
"--dataset",
required=True,
choices=DATASET_MAP.keys(),
help="Dataset to run inference on.",
)
parser.add_argument(
"--hf-repo", type=str, help="HF repo to load the STT model from."
)
parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
parser.add_argument(
"--moshi-weight", type=str, help="Path to a local checkpoint file."
)
parser.add_argument(
"--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
)
parser.add_argument(
"--config-path", type=str, help="Path to a local config file.", default=None
)
parser.add_argument(
"--batch-size",
type=int,
help="Batch size.",
default=32,
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device on which to run, defaults to 'cuda'.",
)
parser.add_argument("--hf-cache-dir", type=str, help="HuggingFace cache folder.")
args = parser.parse_args()
main(args)

View File

@@ -1,100 +0,0 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx==0.2.12",
# "numpy",
# "sentencepiece",
# "sounddevice",
# "sphn",
# ]
# ///
import argparse
import json
import mlx.core as mx
import mlx.nn as nn
import sentencepiece
import sphn
from huggingface_hub import hf_hub_download
from moshi_mlx import models, utils
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("in_file", help="The file to transcribe.")
parser.add_argument("--max-steps", default=4096)
parser.add_argument("--hf-repo")
parser.add_argument(
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
)
args = parser.parse_args()
audio, _ = sphn.read(args.in_file, sample_rate=24000)
if args.hf_repo is None:
if args.vad:
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
else:
args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
lm_config = hf_hub_download(args.hf_repo, "config.json")
with open(lm_config, "r") as fobj:
lm_config = json.load(fobj)
mimi_weights = hf_hub_download(args.hf_repo, lm_config["mimi_name"])
moshi_name = lm_config.get("moshi_name", "model.safetensors")
moshi_weights = hf_hub_download(args.hf_repo, moshi_name)
text_tokenizer = hf_hub_download(args.hf_repo, lm_config["tokenizer_name"])
lm_config = models.LmConfig.from_config_dict(lm_config)
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)
if moshi_weights.endswith(".q4.safetensors"):
nn.quantize(model, bits=4, group_size=32)
elif moshi_weights.endswith(".q8.safetensors"):
nn.quantize(model, bits=8, group_size=64)
print(f"loading model weights from {moshi_weights}")
if args.hf_repo.endswith("-candle"):
model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
else:
model.load_weights(moshi_weights, strict=True)
print(f"loading the text tokenizer from {text_tokenizer}")
text_tokenizer = sentencepiece.SentencePieceProcessor(text_tokenizer) # type: ignore
print(f"loading the audio tokenizer {mimi_weights}")
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(32))
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
print("warming up the model")
model.warmup()
gen = models.LmGen(
model=model,
max_steps=args.max_steps,
text_sampler=utils.Sampler(top_k=25, temp=0),
audio_sampler=utils.Sampler(top_k=250, temp=0.8),
check=False,
)
print(f"starting inference {audio.shape}")
audio = mx.concat([mx.array(audio), mx.zeros((1, 48000))], axis=-1)
last_print_was_vad = False
for start_idx in range(0, audio.shape[-1] // 1920 * 1920, 1920):
block = audio[:, None, start_idx : start_idx + 1920]
other_audio_tokens = audio_tokenizer.encode_step(block).transpose(0, 2, 1)
if args.vad:
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
if vad_heads:
pr_vad = vad_heads[2][0, 0, 0].item()
if pr_vad > 0.5 and not last_print_was_vad:
print(" [end of turn detected]")
last_print_was_vad = True
else:
text_token = gen.step(other_audio_tokens[0])
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
_text = None
if text_token not in (0, 3):
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
last_print_was_vad = False
print()

View File

@ -1,247 +0,0 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "julius",
# "librosa",
# "soundfile",
# "moshi==0.2.11",
# ]
# ///
"""An example script that illustrates how one can get per-word timestamps from
Kyutai STT models.
"""
import argparse
import dataclasses
import itertools
import math
import julius
import moshi.models
import sphn
import time
import torch
@dataclasses.dataclass
class TimestampedText:
text: str
timestamp: tuple[float, float]
def __str__(self):
return f"{self.text} ({self.timestamp[0]:.2f}:{self.timestamp[1]:.2f})"
def tokens_to_timestamped_text(
text_tokens,
tokenizer,
frame_rate,
end_of_padding_id,
padding_token_id,
offset_seconds,
) -> list[TimestampedText]:
text_tokens = text_tokens.cpu().view(-1)
# Normally `end_of_padding` tokens indicate word boundaries.
# Everything between them should be a single word;
# the time offsets of those tokens correspond to word start and
# end timestamps (minus silence prefix and audio delay).
#
# However, in rare cases some complexities could arise. Firstly,
# for words that are said quickly but are represented with
# multiple tokens, the boundary might be omitted. Secondly,
# for the very last word the end boundary might not happen.
# Below is a code snippet that handles those situations a bit
# more carefully.
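# For example (illustrative numbers only): at a 12.5 Hz frame rate with
# offset_seconds = 1.0 + 5.0 = 6.0, an `end_of_padding` token at frame
# position 100 maps to max(0, 100 / 12.5 - 6.0) = 2.0 s.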
sequence_timestamps = []
def _tstmp(start_position, end_position):
return (
max(0, start_position / frame_rate - offset_seconds),
max(0, end_position / frame_rate - offset_seconds),
)
def _decode(t):
t = t[t > padding_token_id]
return tokenizer.decode(t.numpy().tolist())
def _decode_segment(start, end):
nonlocal text_tokens
nonlocal sequence_timestamps
text = _decode(text_tokens[start:end])
words_inside_segment = text.split()
if len(words_inside_segment) == 0:
return
if len(words_inside_segment) == 1:
# Single word within the boundaries, the general case
sequence_timestamps.append(
TimestampedText(text=text, timestamp=_tstmp(start, end))
)
else:
# We're in a rare situation where multiple words are so close they are not separated by `end_of_padding`.
# We tokenize the words one by one; each word is assigned as many frames as it has tokens.
for adjacent_word in words_inside_segment[:-1]:
n_tokens = len(tokenizer.encode(adjacent_word))
sequence_timestamps.append(
TimestampedText(
text=adjacent_word, timestamp=_tstmp(start, start + n_tokens)
)
)
start += n_tokens
# The last word takes everything until the boundary
adjacent_word = words_inside_segment[-1]
sequence_timestamps.append(
TimestampedText(text=adjacent_word, timestamp=_tstmp(start, end))
)
(segment_boundaries,) = torch.where(text_tokens == end_of_padding_id)
if not segment_boundaries.numel():
return []
for i in range(len(segment_boundaries) - 1):
segment_start = int(segment_boundaries[i]) + 1
segment_end = int(segment_boundaries[i + 1])
_decode_segment(segment_start, segment_end)
last_segment_start = segment_boundaries[-1] + 1
boundary_token = torch.tensor([tokenizer.eos_id()])
(end_of_last_segment,) = torch.where(
torch.isin(text_tokens[last_segment_start:], boundary_token)
)
if not end_of_last_segment.numel():
# upper bound: either the end of the audio or 1 second, whichever is smaller
last_segment_end = min(text_tokens.shape[-1], last_segment_start + frame_rate)
else:
last_segment_end = last_segment_start + end_of_last_segment[0]
_decode_segment(last_segment_start, last_segment_end)
return sequence_timestamps
def main(args):
if args.vad and args.hf_repo is None:
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
args.hf_repo,
moshi_weights=args.moshi_weight,
mimi_weights=args.mimi_weight,
tokenizer=args.tokenizer,
config_path=args.config_path,
)
mimi = info.get_mimi(device=args.device)
tokenizer = info.get_text_tokenizer()
lm = info.get_moshi(
device=args.device,
dtype=torch.bfloat16,
)
lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
audio_silence_prefix_seconds = info.stt_config.get(
"audio_silence_prefix_seconds", 1.0
)
audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
padding_token_id = info.raw_config.get("text_padding_token_id", 3)
audio, input_sample_rate = sphn.read(args.in_file)
audio = torch.from_numpy(audio).to(args.device)
audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
if audio.shape[-1] % mimi.frame_size != 0:
to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
audio = torch.nn.functional.pad(audio, (0, to_pad))
text_tokens_accum = []
n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
silence_chunk = torch.zeros(
(1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
)
chunks = itertools.chain(
itertools.repeat(silence_chunk, n_prefix_chunks),
torch.split(audio[:, None], mimi.frame_size, dim=-1),
itertools.repeat(silence_chunk, n_suffix_chunks),
)
start_time = time.time()
nchunks = 0
last_print_was_vad = False
with mimi.streaming(1), lm_gen.streaming(1):
for audio_chunk in chunks:
nchunks += 1
audio_tokens = mimi.encode(audio_chunk)
if args.vad:
text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens)
if vad_heads:
pr_vad = vad_heads[2][0, 0, 0].cpu().item()
if pr_vad > 0.5 and not last_print_was_vad:
print(" [end of turn detected]")
last_print_was_vad = True
else:
text_tokens = lm_gen.step(audio_tokens)
text_token = text_tokens[0, 0, 0].cpu().item()
if text_token not in (0, 3):
_text = tokenizer.id_to_piece(text_tokens[0, 0, 0].cpu().item()) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
last_print_was_vad = False
text_tokens_accum.append(text_tokens)
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
dt = time.time() - start_time
print(
f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}"
)
timed_text = tokens_to_timestamped_text(
utterance_tokens,
tokenizer,
mimi.frame_rate,
end_of_padding_id=0,
padding_token_id=padding_token_id,
offset_seconds=int(n_prefix_chunks / mimi.frame_rate) + audio_delay_seconds,
)
decoded = " ".join([str(t) for t in timed_text])
print(decoded)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Example streaming STT w/ timestamps.")
parser.add_argument("in_file", help="The file to transcribe.")
parser.add_argument(
"--hf-repo", type=str, help="HF repo to load the STT model from. "
)
parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
parser.add_argument(
"--moshi-weight", type=str, help="Path to a local checkpoint file."
)
parser.add_argument(
"--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
)
parser.add_argument(
"--config-path", type=str, help="Path to a local config file.", default=None
)
parser.add_argument(
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device on which to run, defaults to 'cuda'.",
)
args = parser.parse_args()
main(args)

View File

@ -1,135 +0,0 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "msgpack",
# "numpy",
# "sphn",
# "websockets",
# ]
# ///
import argparse
import asyncio
import time
import msgpack
import numpy as np
import sphn
import websockets
SAMPLE_RATE = 24000
FRAME_SIZE = 1920  # Send data in 80 ms chunks (1920 samples at 24 kHz)
def load_and_process_audio(file_path):
"""Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
pcm_data, _ = sphn.read(file_path, sample_rate=SAMPLE_RATE)
return pcm_data[0]
async def receive_messages(websocket):
transcript = []
async for message in websocket:
data = msgpack.unpackb(message, raw=False)
if data["type"] == "Step":
# This message contains the signal from the semantic VAD, and tells us how
# much audio the server has already processed. We don't use either here.
continue
if data["type"] == "Word":
print(data["text"], end=" ", flush=True)
transcript.append(
{
"text": data["text"],
"timestamp": [data["start_time"], data["start_time"]],
}
)
if data["type"] == "EndWord":
if len(transcript) > 0:
transcript[-1]["timestamp"][1] = data["stop_time"]
if data["type"] == "Marker":
# Received marker, stopping stream
break
return transcript
async def send_messages(websocket, rtf: float):
audio_data = load_and_process_audio(args.in_file)
async def send_audio(audio: np.ndarray):
await websocket.send(
msgpack.packb(
{"type": "Audio", "pcm": [float(x) for x in audio]},
use_single_float=True,
)
)
# Start with a second of silence.
# This is needed for the 2.6B model for technical reasons.
await send_audio([0.0] * SAMPLE_RATE)
start_time = time.time()
for i in range(0, len(audio_data), FRAME_SIZE):
await send_audio(audio_data[i : i + FRAME_SIZE])
expected_send_time = start_time + (i + 1) / SAMPLE_RATE / rtf
current_time = time.time()
if current_time < expected_send_time:
await asyncio.sleep(expected_send_time - current_time)
else:
await asyncio.sleep(0.001)
for _ in range(5):
await send_audio([0.0] * SAMPLE_RATE)
# Send a marker to indicate the end of the stream.
await websocket.send(
msgpack.packb({"type": "Marker", "id": 0}, use_single_float=True)
)
# We'll get back the marker once the corresponding audio has been transcribed,
# accounting for the delay of the model. That's why we need to send some silence
# after the marker, because the model will not return the marker immediately.
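# With the model's audio delay being on the order of a few seconds (5.0 s is the
# default used elsewhere in this repo), the 35 s of trailing silence below leaves
# ample margin for the marker to come back.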
for _ in range(35):
await send_audio([0.0] * SAMPLE_RATE)
async def stream_audio(url: str, api_key: str, rtf: float):
"""Stream audio data to a WebSocket server."""
headers = {"kyutai-api-key": api_key}
# Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
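# e.g. (equivalent alternative, sketched): async with websockets.connect(f"{url}?auth_id={api_key}") as websocket: ...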
async with websockets.connect(url, additional_headers=headers) as websocket:
send_task = asyncio.create_task(send_messages(websocket, rtf))
receive_task = asyncio.create_task(receive_messages(websocket))
_, transcript = await asyncio.gather(send_task, receive_task)
return transcript
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("in_file")
parser.add_argument(
"--url",
help="The url of the server to which to send the audio",
default="ws://127.0.0.1:8080",
)
parser.add_argument("--api-key", default="public_token")
parser.add_argument(
"--rtf",
type=float,
default=1.01,
help="The real-time factor of how fast to feed in the audio.",
)
args = parser.parse_args()
url = f"{args.url}/api/asr-streaming"
transcript = asyncio.run(stream_audio(url, args.api_key, args.rtf))
print()
print()
for word in transcript:
print(
f"{word['timestamp'][0]:7.2f} -{word['timestamp'][1]:7.2f} {word['text']}"
)

View File

@ -1,187 +0,0 @@
"""An example script that illustrates how one can prompt Kyutai STT models."""
import argparse
import itertools
import math
from collections import deque
import julius
import moshi.models
import sphn
import torch
import tqdm
class PromptHook:
def __init__(self, tokenizer, prefix, padding_tokens=(0, 3)):
self.tokenizer = tokenizer
self.prefix_enforce = deque(self.tokenizer.encode(prefix))
self.padding_tokens = padding_tokens
def on_token(self, token):
if not self.prefix_enforce:
return
token = token.item()
if token in self.padding_tokens:
pass
elif token == self.prefix_enforce[0]:
self.prefix_enforce.popleft()
else:
assert False, f"unexpected token {token} while enforcing the prompt prefix"
def on_logits(self, logits):
if not self.prefix_enforce:
return
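# Mask the logits so that only padding tokens or the next expected prefix token
# can be sampled; every other token gets -inf.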
mask = torch.zeros_like(logits, dtype=torch.bool)
for t in self.padding_tokens:
mask[..., t] = True
mask[..., self.prefix_enforce[0]] = True
logits[:] = torch.where(mask, logits, float("-inf"))
def main(args):
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
args.hf_repo,
moshi_weights=args.moshi_weight,
mimi_weights=args.mimi_weight,
tokenizer=args.tokenizer,
config_path=args.config_path,
)
mimi = info.get_mimi(device=args.device)
tokenizer = info.get_text_tokenizer()
lm = info.get_moshi(
device=args.device,
dtype=torch.bfloat16,
)
if args.prompt_text:
prompt_hook = PromptHook(tokenizer, args.prompt_text)
lm_gen = moshi.models.LMGen(
lm,
temp=0,
temp_text=0.0,
on_text_hook=prompt_hook.on_token,
on_text_logits_hook=prompt_hook.on_logits,
)
else:
lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
audio_silence_prefix_seconds = info.stt_config.get(
"audio_silence_prefix_seconds", 1.0
)
audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
padding_token_id = info.raw_config.get("text_padding_token_id", 3)
def _load_and_process(path):
audio, input_sample_rate = sphn.read(path)
audio = torch.from_numpy(audio).to(args.device).mean(axis=0, keepdim=True)
audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
if audio.shape[-1] % mimi.frame_size != 0:
to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
audio = torch.nn.functional.pad(audio, (0, to_pad))
return audio
n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
silence_chunk = torch.zeros(
(1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
)
audio = _load_and_process(args.file)
if args.prompt_file:
audio_prompt = _load_and_process(args.prompt_file)
else:
audio_prompt = None
chain = [itertools.repeat(silence_chunk, n_prefix_chunks)]
if audio_prompt is not None:
chain.append(torch.split(audio_prompt[:, None, :], mimi.frame_size, dim=-1))
# adding a bit (0.8s, i.e. 10 frames) of silence to separate the prompt from the actual audio
chain.append(itertools.repeat(silence_chunk, 10))
chain += [
torch.split(audio[:, None, :], mimi.frame_size, dim=-1),
itertools.repeat(silence_chunk, n_suffix_chunks),
]
chunks = itertools.chain(*chain)
text_tokens_accum = []
with mimi.streaming(1), lm_gen.streaming(1):
for audio_chunk in tqdm.tqdm(chunks):
audio_tokens = mimi.encode(audio_chunk)
text_tokens = lm_gen.step(audio_tokens)
if text_tokens is not None:
text_tokens_accum.append(text_tokens)
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
text_tokens = utterance_tokens.cpu().view(-1)
# if we have an audio prompt and we don't want to have it in the transcript,
# we should cut the corresponding number of frames from the output tokens.
# However, there is also some amount of padding that happens before it
# due to silence_prefix and audio_delay. Normally it is ignored in detokenization,
# but now we should account for it to find the position of the prompt transcript.
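# For example (illustrative arithmetic, not from the original script): at Mimi's
# 12.5 Hz frame rate, a 4 s prompt gives prompt_frames = 50, and the 1.0 s silence
# prefix plus the 5.0 s audio delay give no_prompt_offset = 75, so the first 125
# text tokens would be cut.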
if args.cut_prompt_transcript and audio_prompt is not None:
prompt_frames = audio_prompt.shape[1] // mimi.frame_size
no_prompt_offset_seconds = audio_delay_seconds + audio_silence_prefix_seconds
no_prompt_offset = int(no_prompt_offset_seconds * mimi.frame_rate)
text_tokens = text_tokens[prompt_frames + no_prompt_offset :]
text = tokenizer.decode(
text_tokens[text_tokens > padding_token_id].numpy().tolist()
)
print(text)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Example streaming STT w/ a prompt.")
parser.add_argument(
"--file",
required=True,
help="File to transcribe.",
)
parser.add_argument(
"--prompt_file",
required=False,
help="Audio of the prompt.",
)
parser.add_argument(
"--prompt_text",
required=False,
help="Text of the prompt.",
)
parser.add_argument(
"--cut-prompt-transcript",
action="store_true",
help="Cut the prompt from the output transcript",
)
parser.add_argument(
"--hf-repo", type=str, help="HF repo to load the STT model from. "
)
parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
parser.add_argument(
"--moshi-weight", type=str, help="Path to a local checkpoint file."
)
parser.add_argument(
"--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
)
parser.add_argument(
"--config-path", type=str, help="Path to a local config file.", default=None
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device on which to run, defaults to 'cuda'.",
)
args = parser.parse_args()
main(args)

View File

@ -1,116 +0,0 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx==0.2.12",
# "numpy",
# "rustymimi",
# "sentencepiece",
# "sounddevice",
# ]
# ///
import argparse
import json
import queue
import mlx.core as mx
import mlx.nn as nn
import rustymimi
import sentencepiece
import sounddevice as sd
from huggingface_hub import hf_hub_download
from moshi_mlx import models, utils
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--max-steps", default=4096)
parser.add_argument("--hf-repo")
parser.add_argument(
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
)
args = parser.parse_args()
if args.hf_repo is None:
if args.vad:
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
else:
args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
lm_config = hf_hub_download(args.hf_repo, "config.json")
with open(lm_config, "r") as fobj:
lm_config = json.load(fobj)
mimi_weights = hf_hub_download(args.hf_repo, lm_config["mimi_name"])
moshi_name = lm_config.get("moshi_name", "model.safetensors")
moshi_weights = hf_hub_download(args.hf_repo, moshi_name)
tokenizer = hf_hub_download(args.hf_repo, lm_config["tokenizer_name"])
lm_config = models.LmConfig.from_config_dict(lm_config)
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)
if moshi_weights.endswith(".q4.safetensors"):
nn.quantize(model, bits=4, group_size=32)
elif moshi_weights.endswith(".q8.safetensors"):
nn.quantize(model, bits=8, group_size=64)
print(f"loading model weights from {moshi_weights}")
if args.hf_repo.endswith("-candle"):
model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
else:
model.load_weights(moshi_weights, strict=True)
print(f"loading the text tokenizer from {tokenizer}")
text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer) # type: ignore
print(f"loading the audio tokenizer {mimi_weights}")
generated_codebooks = lm_config.generated_codebooks
other_codebooks = lm_config.other_codebooks
mimi_codebooks = max(generated_codebooks, other_codebooks)
audio_tokenizer = rustymimi.Tokenizer(mimi_weights, num_codebooks=mimi_codebooks) # type: ignore
print("warming up the model")
model.warmup()
gen = models.LmGen(
model=model,
max_steps=args.max_steps,
text_sampler=utils.Sampler(top_k=25, temp=0),
audio_sampler=utils.Sampler(top_k=250, temp=0.8),
check=False,
)
block_queue = queue.Queue()
def audio_callback(indata, _frames, _time, _status):
block_queue.put(indata.copy())
print("recording audio from microphone, speak to get your words transcribed")
last_print_was_vad = False
with sd.InputStream(
channels=1,
dtype="float32",
samplerate=24000,
blocksize=1920,
callback=audio_callback,
):
while True:
block = block_queue.get()
block = block[None, :, 0]
other_audio_tokens = audio_tokenizer.encode_step(block[None, 0:1])
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
:, :, :other_codebooks
]
if args.vad:
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
if vad_heads:
pr_vad = vad_heads[2][0, 0, 0].item()
if pr_vad > 0.5 and not last_print_was_vad:
print(" [end of turn detected]")
last_print_was_vad = True
else:
text_token = gen.step(other_audio_tokens[0])
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
_text = None
if text_token not in (0, 3):
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
last_print_was_vad = False

View File

@ -1,135 +0,0 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "msgpack",
# "numpy",
# "sounddevice",
# "websockets",
# ]
# ///
import argparse
import asyncio
import signal
import msgpack
import numpy as np
import sounddevice as sd
import websockets
SAMPLE_RATE = 24000
# The VAD has several prediction heads, each of which tries to determine whether there
# has been a pause of a given length. The lengths are 0.5, 1.0, 2.0, and 3.0 seconds.
# Lower indices predict pauses more aggressively. In Unmute, we use 2.0 seconds = index 2.
PAUSE_PREDICTION_HEAD_INDEX = 2
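# In a received "Step" message, data["prs"][PAUSE_PREDICTION_HEAD_INDEX] > 0.5 then
# means the model believes a pause of at least 2 seconds has occurred (see
# receive_messages below).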
async def receive_messages(websocket, show_vad: bool = False):
"""Receive and process messages from the WebSocket server."""
try:
speech_started = False
async for message in websocket:
data = msgpack.unpackb(message, raw=False)
# The Step message only gets sent if the model has semantic VAD available
if data["type"] == "Step" and show_vad:
pause_prediction = data["prs"][PAUSE_PREDICTION_HEAD_INDEX]
if pause_prediction > 0.5 and speech_started:
print("| ", end="", flush=True)
speech_started = False
elif data["type"] == "Word":
print(data["text"], end=" ", flush=True)
speech_started = True
except websockets.ConnectionClosed:
print("Connection closed while receiving messages.")
async def send_messages(websocket, audio_queue):
"""Send audio data from microphone to WebSocket server."""
try:
# Start by draining the queue to avoid lags
while not audio_queue.empty():
await audio_queue.get()
print("Starting the transcription")
while True:
audio_data = await audio_queue.get()
chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data]}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
except websockets.ConnectionClosed:
print("Connection closed while sending messages.")
async def stream_audio(url: str, api_key: str, show_vad: bool):
"""Stream audio data to a WebSocket server."""
print("Starting microphone recording...")
print("Press Ctrl+C to stop recording")
audio_queue = asyncio.Queue()
loop = asyncio.get_event_loop()
def audio_callback(indata, frames, time, status):
loop.call_soon_threadsafe(
audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy()
)
# Start audio stream
with sd.InputStream(
samplerate=SAMPLE_RATE,
channels=1,
dtype="float32",
callback=audio_callback,
blocksize=1920, # 80ms blocks
):
headers = {"kyutai-api-key": api_key}
# Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
async with websockets.connect(url, additional_headers=headers) as websocket:
send_task = asyncio.create_task(send_messages(websocket, audio_queue))
receive_task = asyncio.create_task(
receive_messages(websocket, show_vad=show_vad)
)
await asyncio.gather(send_task, receive_task)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Real-time microphone transcription")
parser.add_argument(
"--url",
help="The URL of the server to which to send the audio",
default="ws://127.0.0.1:8080",
)
parser.add_argument("--api-key", default="public_token")
parser.add_argument(
"--list-devices", action="store_true", help="List available audio devices"
)
parser.add_argument(
"--device", type=int, help="Input device ID (use --list-devices to see options)"
)
parser.add_argument(
"--show-vad",
action="store_true",
help="Visualize the predictions of the semantic voice activity detector with a '|' symbol",
)
args = parser.parse_args()
def handle_sigint(signum, frame):
print("Interrupted by user") # Don't complain about KeyboardInterrupt
exit(0)
signal.signal(signal.SIGINT, handle_sigint)
if args.list_devices:
print("Available audio devices:")
print(sd.query_devices())
exit(0)
if args.device is not None:
sd.default.device[0] = args.device # Set input device
url = f"{args.url}/api/asr-streaming"
asyncio.run(stream_audio(url, args.api_key, args.show_vad))

View File

@ -1,206 +0,0 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx==0.2.12",
# "numpy",
# "sounddevice",
# ]
# ///
import argparse
import json
import queue
import sys
import time
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import sentencepiece
import sounddevice as sd
import sphn
from moshi_mlx import models
from moshi_mlx.client_utils import make_log
from moshi_mlx.models.tts import (
DEFAULT_DSM_TTS_REPO,
DEFAULT_DSM_TTS_VOICE_REPO,
TTSModel,
)
from moshi_mlx.utils.loaders import hf_get
def log(level: str, msg: str):
print(make_log(level, msg))
def main():
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the MLX implementation"
)
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--hf-repo",
type=str,
default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
)
parser.add_argument(
"--quantize",
type=int,
help="The quantization to be applied, e.g. 8 for 8 bits.",
)
args = parser.parse_args()
mx.random.seed(299792458)
log("info", "retrieving checkpoints")
raw_config = hf_get("config.json", args.hf_repo)
with open(hf_get(raw_config), "r") as fobj:
raw_config = json.load(fobj)
mimi_weights = hf_get(raw_config["mimi_name"], args.hf_repo)
moshi_name = raw_config.get("moshi_name", "model.safetensors")
moshi_weights = hf_get(moshi_name, args.hf_repo)
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
lm_config = models.LmConfig.from_config_dict(raw_config)
# There is a bug in the ring KV cache handling of moshi_mlx <= 0.3.0;
# the following line works around it for now.
lm_config.transformer.max_seq_len = lm_config.transformer.context
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)
log("info", f"loading model weights from {moshi_weights}")
model.load_pytorch_weights(str(moshi_weights), lm_config, strict=True)
if args.quantize is not None:
log("info", f"quantizing model to {args.quantize} bits")
nn.quantize(model.depformer, bits=args.quantize)
for layer in model.transformer.layers:
nn.quantize(layer.self_attn, bits=args.quantize)
nn.quantize(layer.gating, bits=args.quantize)
log("info", f"loading the text tokenizer from {tokenizer}")
text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer)) # type: ignore
log("info", f"loading the audio tokenizer {mimi_weights}")
generated_codebooks = lm_config.generated_codebooks
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(generated_codebooks))
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
cfg_coef_conditioning = None
tts_model = TTSModel(
model,
audio_tokenizer,
text_tokenizer,
voice_repo=args.voice_repo,
temp=0.6,
cfg_coef=1,
max_padding=8,
initial_padding=2,
final_padding=2,
padding_bonus=0,
raw_config=raw_config,
)
if tts_model.valid_cfg_conditionings:
# Model was trained with CFG distillation.
cfg_coef_conditioning = tts_model.cfg_coef
tts_model.cfg_coef = 1.0
cfg_is_no_text = False
cfg_is_no_prefix = False
else:
cfg_is_no_text = True
cfg_is_no_prefix = True
mimi = tts_model.mimi
log("info", f"reading input from {args.inp}")
if args.inp == "-":
if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
text_to_tts = sys.stdin.read().strip()
else:
with open(args.inp, "r") as fobj:
text_to_tts = fobj.read().strip()
all_entries = [tts_model.prepare_script([text_to_tts])]
if tts_model.multi_speaker:
voices = [tts_model.get_voice_path(args.voice)]
else:
voices = []
all_attributes = [
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
]
wav_frames = queue.Queue()
def _on_frame(frame):
if (frame == -1).any():
return
_pcm = tts_model.mimi.decode_step(frame[:, :, None])
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
wav_frames.put_nowait(_pcm)
def run():
log("info", "starting the inference loop")
begin = time.time()
result = tts_model.generate(
all_entries,
all_attributes,
cfg_is_no_prefix=cfg_is_no_prefix,
cfg_is_no_text=cfg_is_no_text,
on_frame=_on_frame,
)
frames = mx.concat(result.frames, axis=-1)
total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
time_taken = time.time() - begin
total_speed = total_duration / time_taken
log("info", f"[LM] took {time_taken:.2f}s, total speed {total_speed:.2f}x")
return result
if args.out == "-":
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = wav_frames.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
with sd.OutputStream(
samplerate=mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback,
):
run()
time.sleep(3)
while True:
if wav_frames.qsize() == 0:
break
time.sleep(1)
else:
run()
frames = []
while True:
try:
frames.append(wav_frames.get_nowait())
except queue.Empty:
break
wav = np.concat(frames, -1)
sphn.write_wav(args.out, wav, mimi.sample_rate)
if __name__ == "__main__":
main()

View File

@ -1,317 +0,0 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx==0.2.12",
# "numpy",
# "sounddevice",
# ]
# ///
import argparse
from dataclasses import dataclass
import json
import queue
import sys
import time
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import sentencepiece
import sounddevice as sd
import sphn
import typing as tp
from moshi_mlx import models
from moshi_mlx.models.generate import LmGen
from moshi_mlx.client_utils import make_log
from moshi_mlx.modules.conditioner import (
ConditionAttributes,
ConditionTensor,
dropout_all_conditions,
)
from moshi_mlx.utils.sampling import Sampler
from moshi_mlx.models.tts import (
Entry,
DEFAULT_DSM_TTS_REPO,
DEFAULT_DSM_TTS_VOICE_REPO,
TTSModel,
script_to_entries,
)
from moshi_mlx.utils.loaders import hf_get
def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]:
multi_speaker = first_turn and model.multi_speaker
return script_to_entries(
model.tokenizer,
model.machine.token_ids,
model.mimi.frame_rate,
[script],
multi_speaker=multi_speaker,
padding_between=1,
)
def _make_null(
all_attributes: tp.Sequence[ConditionAttributes],
) -> list[ConditionAttributes]:
# When using CFG, returns the null conditions.
return dropout_all_conditions(all_attributes)
@dataclass
class TTSGen:
tts_model: TTSModel
attributes: tp.Sequence[ConditionAttributes]
on_frame: tp.Optional[tp.Callable[[mx.array], None]] = None
def __post_init__(self):
tts_model = self.tts_model
attributes = self.attributes
self.offset = 0
self.state = self.tts_model.machine.new_state([])
if tts_model.cfg_coef != 1.0:
if tts_model.valid_cfg_conditionings:
raise ValueError(
"This model does not support direct CFG, but was trained with "
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
)
nulled = _make_null(attributes)
attributes = list(attributes) + nulled
assert tts_model.lm.condition_provider is not None
self.ct = None
self.cross_attention_src = None
for _attr in attributes:
for _key, _value in _attr.text.items():
_ct = tts_model.lm.condition_provider.condition_tensor(_key, _value)
if self.ct is None:
self.ct = _ct
else:
self.ct = ConditionTensor(self.ct.tensor + _ct.tensor)
for _key, _value in _attr.tensor.items():
_conditioner = tts_model.lm.condition_provider.conditioners[_key]
_ca_src = _conditioner.condition(_value)
if self.cross_attention_src is None:
self.cross_attention_src = _ca_src
else:
raise ValueError("multiple cross-attention conditioners")
def _on_audio_hook(audio_tokens):
delays = tts_model.lm.delays
for q in range(audio_tokens.shape[0]):
delay = delays[q]
if self.offset < delay + tts_model.delay_steps:
audio_tokens[q] = tts_model.machine.token_ids.zero
def _on_text_hook(text_tokens):
tokens = text_tokens.tolist()
out_tokens = []
for token in tokens:
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
out_tokens.append(out_token)
text_tokens[:] = mx.array(out_tokens, dtype=mx.int64)
self.lm_gen = LmGen(
tts_model.lm,
max_steps=tts_model.max_gen_length,
text_sampler=Sampler(temp=tts_model.temp),
audio_sampler=Sampler(temp=tts_model.temp),
cfg_coef=tts_model.cfg_coef,
on_text_hook=_on_text_hook,
on_audio_hook=_on_audio_hook,
# TODO(laurent):
# cfg_is_masked_until=cfg_is_masked_until,
# cfg_is_no_text=cfg_is_no_text,
)
def process_last(self):
while len(self.state.entries) > 0 or self.state.end_step is not None:
self._step()
additional_steps = (
self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
)
for _ in range(additional_steps):
self._step()
def process(self):
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
self._step()
def _step(self):
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
input_tokens = (
mx.ones((1, missing), dtype=mx.int64)
* self.tts_model.machine.token_ids.zero
)
self.lm_gen.step(
input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src
)
frame = self.lm_gen.last_audio_tokens()
self.offset += 1
if frame is not None:
if self.on_frame is not None:
self.on_frame(frame)
def append_entry(self, entry):
self.state.entries.append(entry)
def log(level: str, msg: str):
print(make_log(level, msg))
def main():
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the MLX implementation"
)
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--hf-repo",
type=str,
default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
)
parser.add_argument(
"--quantize",
type=int,
help="The quantization to be applied, e.g. 8 for 8 bits.",
)
args = parser.parse_args()
mx.random.seed(299792458)
log("info", "retrieving checkpoints")
raw_config = hf_get("config.json", args.hf_repo)
with open(hf_get(raw_config), "r") as fobj:
raw_config = json.load(fobj)
mimi_weights = hf_get(raw_config["mimi_name"], args.hf_repo)
moshi_name = raw_config.get("moshi_name", "model.safetensors")
moshi_weights = hf_get(moshi_name, args.hf_repo)
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
lm_config = models.LmConfig.from_config_dict(raw_config)
# There is a bug in the ring KV cache handling of moshi_mlx <= 0.3.0;
# the following line works around it for now.
lm_config.transformer.max_seq_len = lm_config.transformer.context
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)
log("info", f"loading model weights from {moshi_weights}")
model.load_pytorch_weights(str(moshi_weights), lm_config, strict=True)
if args.quantize is not None:
log("info", f"quantizing model to {args.quantize} bits")
nn.quantize(model.depformer, bits=args.quantize)
for layer in model.transformer.layers:
nn.quantize(layer.self_attn, bits=args.quantize)
nn.quantize(layer.gating, bits=args.quantize)
log("info", f"loading the text tokenizer from {tokenizer}")
text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer)) # type: ignore
log("info", f"loading the audio tokenizer {mimi_weights}")
generated_codebooks = lm_config.generated_codebooks
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(generated_codebooks))
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
cfg_coef_conditioning = None
tts_model = TTSModel(
model,
audio_tokenizer,
text_tokenizer,
voice_repo=args.voice_repo,
temp=0.6,
cfg_coef=1,
max_padding=8,
initial_padding=2,
final_padding=2,
padding_bonus=0,
raw_config=raw_config,
)
if tts_model.valid_cfg_conditionings:
# Model was trained with CFG distillation.
cfg_coef_conditioning = tts_model.cfg_coef
tts_model.cfg_coef = 1.0
mimi = tts_model.mimi
log("info", "reading input from stdin")
if tts_model.multi_speaker:
voices = [tts_model.get_voice_path(args.voice)]
else:
voices = []
all_attributes = [
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
]
wav_frames = queue.Queue()
def _on_frame(frame):
if (frame == -1).any():
return
_pcm = tts_model.mimi.decode_step(frame[:, :, None])
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
wav_frames.put_nowait(_pcm)
gen = TTSGen(tts_model, all_attributes, on_frame=_on_frame)
def run():
log("info", "starting the inference loop")
first_turn = True
for line in sys.stdin:
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
first_turn = False
for entry in entries:
gen.append_entry(entry)
gen.process()
gen.process_last()
if args.out == "-":
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = wav_frames.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
with sd.OutputStream(
samplerate=mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback,
):
run()
while True:
if wav_frames.qsize() == 0:
break
time.sleep(1)
else:
run()
frames = []
while True:
try:
frames.append(wav_frames.get_nowait())
except queue.Empty:
break
wav = np.concat(frames, -1)
sphn.write_wav(args.out, wav, mimi.sample_rate)
if __name__ == "__main__":
main()

View File

@ -1,140 +0,0 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "moshi==0.2.11",
# "torch",
# "sphn",
# "sounddevice",
# ]
# ///
import argparse
import sys
import numpy as np
import queue
import sphn
import time
import torch
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
def main():
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation"
)
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--hf-repo",
type=str,
default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice",
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
help="The voice to use, relative to the voice repo root. "
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device on which to run, defaults to 'cuda'.",
)
args = parser.parse_args()
print("Loading model...")
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
tts_model = TTSModel.from_checkpoint_info(
checkpoint_info, n_q=32, temp=0.6, device=args.device
)
if args.inp == "-":
if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
text = sys.stdin.read().strip()
else:
with open(args.inp, "r") as fobj:
text = fobj.read().strip()
# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]
entries = tts_model.prepare_script([text], padding_between=1)
if args.voice.endswith(".safetensors"):
voice_path = args.voice
else:
voice_path = tts_model.get_voice_path(args.voice)
# CFG coef goes here because the model was trained with CFG distillation,
# so it's not _actually_ doing CFG at inference time.
# Also, if you are generating a dialog, you should have two voices in the list.
condition_attributes = tts_model.make_condition_attributes(
[voice_path], cfg_coef=2.0
)
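# A hypothetical dialog setup (illustrative sketch only, not part of this script;
# voice_path_speaker_1/2 are placeholder paths): pass one entry per turn and one
# voice per speaker, e.g.
# entries = tts_model.prepare_script(
#     ["Hi, how are you?", "Great, thanks!", "Glad to hear it."], padding_between=1
# )
# condition_attributes = tts_model.make_condition_attributes(
#     [voice_path_speaker_1, voice_path_speaker_2], cfg_coef=2.0
# )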
_frames_cnt = 0
if args.out == "-":
# Stream the audio to the speakers using sounddevice.
import sounddevice as sd
pcms = queue.Queue()
def _on_frame(frame):
nonlocal _frames_cnt
if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
_frames_cnt += 1
print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = pcms.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
with sd.OutputStream(
samplerate=tts_model.mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback,
):
with tts_model.mimi.streaming(1):
tts_model.generate(
[entries], [condition_attributes], on_frame=_on_frame
)
time.sleep(3)
while True:
if pcms.qsize() == 0:
break
time.sleep(1)
else:
def _on_frame(frame):
nonlocal _frames_cnt
if (frame != -1).all():
_frames_cnt += 1
print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
result = tts_model.generate(
[entries], [condition_attributes], on_frame=_on_frame
)
with tts_model.mimi.streaming(1), torch.no_grad():
pcms = []
for frame in result.frames[tts_model.delay_steps :]:
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.append(np.clip(pcm[0, 0], -1, 1))
pcm = np.concatenate(pcms, axis=-1)
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
if __name__ == "__main__":
main()

View File

@ -1,261 +0,0 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "moshi==0.2.11",
# "torch",
# "sphn",
# "sounddevice",
# ]
# ///
import argparse
from dataclasses import dataclass
import sys
import numpy as np
import queue
import sphn
import time
import torch
import typing as tp
from moshi.models.loaders import CheckpointInfo
from moshi.conditioners import dropout_all_conditions
from moshi.models.lm import LMGen
from moshi.models.tts import (
Entry,
DEFAULT_DSM_TTS_REPO,
DEFAULT_DSM_TTS_VOICE_REPO,
TTSModel,
ConditionAttributes,
script_to_entries,
)
def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]:
multi_speaker = first_turn and model.multi_speaker
return script_to_entries(
model.tokenizer,
model.machine.token_ids,
model.mimi.frame_rate,
[script],
multi_speaker=multi_speaker,
padding_between=1,
)
def _make_null(
all_attributes: tp.Sequence[ConditionAttributes],
) -> list[ConditionAttributes]:
# When using CFG, returns the null conditions.
return dropout_all_conditions(all_attributes)
@dataclass
class TTSGen:
tts_model: TTSModel
attributes: tp.Sequence[ConditionAttributes]
on_frame: tp.Optional[tp.Callable[[torch.Tensor], None]] = None
def __post_init__(self):
tts_model = self.tts_model
attributes = self.attributes
self.offset = 0
self.state = self.tts_model.machine.new_state([])
if tts_model.cfg_coef != 1.0:
if tts_model.valid_cfg_conditionings:
raise ValueError(
"This model does not support direct CFG, but was trained with "
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
)
nulled = _make_null(attributes)
attributes = list(attributes) + nulled
assert tts_model.lm.condition_provider is not None
prepared = tts_model.lm.condition_provider.prepare(attributes)
condition_tensors = tts_model.lm.condition_provider(prepared)
def _on_text_logits_hook(text_logits):
if tts_model.padding_bonus:
text_logits[..., tts_model.machine.token_ids.pad] += (
tts_model.padding_bonus
)
return text_logits
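# _on_audio_hook below keeps the audio codebooks forced to the "zero" token until
# each codebook's delay plus the model's delay_steps has elapsed.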
def _on_audio_hook(audio_tokens):
audio_offset = tts_model.lm.audio_offset
delays = tts_model.lm.delays
for q in range(audio_tokens.shape[1]):
delay = delays[q + audio_offset]
if self.offset < delay + tts_model.delay_steps:
audio_tokens[:, q] = tts_model.machine.token_ids.zero
def _on_text_hook(text_tokens):
tokens = text_tokens.tolist()
out_tokens = []
for token in tokens:
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
out_tokens.append(out_token)
text_tokens[:] = torch.tensor(
out_tokens, dtype=torch.long, device=text_tokens.device
)
tts_model.lm.dep_q = tts_model.n_q
self.lm_gen = LMGen(
tts_model.lm,
temp=tts_model.temp,
temp_text=tts_model.temp,
cfg_coef=tts_model.cfg_coef,
condition_tensors=condition_tensors,
on_text_logits_hook=_on_text_logits_hook,
on_text_hook=_on_text_hook,
on_audio_hook=_on_audio_hook,
cfg_is_masked_until=None,
cfg_is_no_text=True,
)
self.lm_gen.streaming_forever(1)
def process_last(self):
while len(self.state.entries) > 0 or self.state.end_step is not None:
self._step()
additional_steps = (
self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
)
for _ in range(additional_steps):
self._step()
def process(self):
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
self._step()
def _step(self):
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
input_tokens = torch.full(
(1, missing, 1),
self.tts_model.machine.token_ids.zero,
dtype=torch.long,
device=self.tts_model.lm.device,
)
frame = self.lm_gen.step(input_tokens)
self.offset += 1
if frame is not None:
if self.on_frame is not None:
self.on_frame(frame)
def append_entry(self, entry):
self.state.entries.append(entry)
@torch.no_grad()
def main():
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation"
)
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--hf-repo",
type=str,
default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice",
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
help="The voice to use, relative to the voice repo root. "
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device on which to run, defaults to 'cuda'.",
)
args = parser.parse_args()
print("Loading model...")
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
tts_model = TTSModel.from_checkpoint_info(
checkpoint_info, n_q=32, temp=0.6, device=args.device
)
if args.voice.endswith(".safetensors"):
voice_path = args.voice
else:
voice_path = tts_model.get_voice_path(args.voice)
# CFG coef goes here because the model was trained with CFG distillation,
# so it's not _actually_ doing CFG at inference time.
# Also, if you are generating a dialog, you should have two voices in the list.
condition_attributes = tts_model.make_condition_attributes(
[voice_path], cfg_coef=2.0
)
if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
if args.out == "-":
# Stream the audio to the speakers using sounddevice.
import sounddevice as sd
pcms = queue.Queue()
def _on_frame(frame):
if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = pcms.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame)
with sd.OutputStream(
samplerate=tts_model.mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback,
), tts_model.mimi.streaming(1):
first_turn = True
for line in sys.stdin:
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
first_turn = False
for entry in entries:
gen.append_entry(entry)
gen.process()
gen.process_last()
while True:
if pcms.qsize() == 0:
break
time.sleep(1)
else:
pcms = []
def _on_frame(frame: torch.Tensor):
if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.append(np.clip(pcm[0, 0], -1, 1))
gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame)
with tts_model.mimi.streaming(1):
first_turn = True
for line in sys.stdin:
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
first_turn = False
for entry in entries:
gen.append_entry(entry)
gen.process()
gen.process_last()
pcm = np.concatenate(pcms, axis=-1)
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
if __name__ == "__main__":
main()

View File

@ -1,178 +0,0 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "msgpack",
# "numpy",
# "sphn",
# "websockets",
# "sounddevice",
# "tqdm",
# ]
# ///
import argparse
import asyncio
import sys
from urllib.parse import urlencode
import msgpack
import numpy as np
import sounddevice as sd
import sphn
import tqdm
import websockets
SAMPLE_RATE = 24000
TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."
DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
AUTH_TOKEN = "public_token"
async def receive_messages(websocket: websockets.ClientConnection, output_queue):
with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar:
accumulated_samples = 0
last_seconds = 0
async for message_bytes in websocket:
msg = msgpack.unpackb(message_bytes)
if msg["type"] == "Audio":
pcm = np.array(msg["pcm"]).astype(np.float32)
await output_queue.put(pcm)
accumulated_samples += len(msg["pcm"])
current_seconds = accumulated_samples // SAMPLE_RATE
if current_seconds > last_seconds:
pbar.update(current_seconds - last_seconds)
last_seconds = current_seconds
print("End of audio.")
await output_queue.put(None) # Signal end of audio
async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
if out == "-":
should_exit = False
def audio_callback(outdata, _a, _b, _c):
nonlocal should_exit
try:
pcm_data = output_queue.get_nowait()
if pcm_data is not None:
outdata[:, 0] = pcm_data
else:
should_exit = True
outdata[:] = 0
except asyncio.QueueEmpty:
outdata[:] = 0
with sd.OutputStream(
samplerate=SAMPLE_RATE,
blocksize=1920,
channels=1,
callback=audio_callback,
):
while True:
if should_exit:
break
await asyncio.sleep(1)
else:
frames = []
while True:
item = await output_queue.get()
if item is None:
break
frames.append(item)
sphn.write_wav(out, np.concat(frames, -1), SAMPLE_RATE)
print(f"Saved audio to {out}")
async def read_lines_from_stdin():
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
loop = asyncio.get_running_loop()
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
while True:
line = await reader.readline()
if not line:
break
yield line.decode().rstrip()
async def read_lines_from_file(path: str):
queue = asyncio.Queue()
loop = asyncio.get_running_loop()
def producer():
with open(path, "r", encoding="utf-8") as f:
for line in f:
asyncio.run_coroutine_threadsafe(queue.put(line), loop)
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
await asyncio.to_thread(producer)
while True:
line = await queue.get()
if line is None:
break
yield line
async def get_lines(source: str):
if source == "-":
async for line in read_lines_from_stdin():
yield line
else:
async for line in read_lines_from_file(source):
yield line
async def websocket_client():
parser = argparse.ArgumentParser(description="Use the TTS streaming API")
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--voice",
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
help="The voice to use, relative to the voice repo root. "
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
)
parser.add_argument(
"--url",
help="The URL of the server to which to send the audio",
default="ws://127.0.0.1:8080",
)
parser.add_argument("--api-key", default="public_token")
args = parser.parse_args()
params = {"voice": args.voice, "format": "PcmMessagePack"}
uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
print(uri)
if args.inp == "-":
if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
headers = {"kyutai-api-key": args.api_key}
async with websockets.connect(uri, additional_headers=headers) as websocket:
print("connected")
async def send_loop():
print("go send")
async for line in get_lines(args.inp):
for word in line.split():
await websocket.send(msgpack.packb({"type": "Text", "text": word}))
await websocket.send(msgpack.packb({"type": "Eos"}))
output_queue = asyncio.Queue()
receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
send_task = asyncio.create_task(send_loop())
await asyncio.gather(receive_task, output_audio_task, send_task)
if __name__ == "__main__":
asyncio.run(websocket_client())

633
stt-rs/Cargo.lock generated
View File

@ -97,12 +97,6 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "atomic-waker"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "audiopus_sys"
version = "0.2.2"
@ -239,7 +233,7 @@ dependencies = [
"metal 0.27.0",
"num-traits",
"num_cpus",
"rand",
"rand 0.9.1",
"rand_distr",
"rayon",
"safetensors",
@ -301,7 +295,7 @@ dependencies = [
"candle-nn",
"fancy-regex",
"num-traits",
"rand",
"rand 0.9.1",
"rayon",
"serde",
"serde_json",
@ -482,23 +476,23 @@ dependencies = [
[[package]]
name = "dirs"
version = "6.0.0"
version = "5.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e"
checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
dependencies = [
"dirs-sys",
]
[[package]]
name = "dirs-sys"
version = "0.5.0"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab"
checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c"
dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.60.2",
"windows-sys 0.48.0",
]
[[package]]
@ -613,12 +607,6 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "fnv"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foreign-types"
version = "0.3.2"
@ -670,48 +658,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
name = "futures-core"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.31"
@ -747,13 +699,9 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
@ -1031,25 +979,6 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]]
name = "h2"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9421a676d1b147b16b82c9225157dc629087ef8ec4d5e2960f9437a90dac0a5"
dependencies = [
"atomic-waker",
"bytes",
"fnv",
"futures-core",
"futures-sink",
"http",
"indexmap",
"slab",
"tokio",
"tokio-util 0.7.15",
"tracing",
]
[[package]]
name = "half"
version = "2.6.0"
@ -1060,7 +989,7 @@ dependencies = [
"cfg-if",
"crunchy",
"num-traits",
"rand",
"rand 0.9.1",
"rand_distr",
]
@ -1084,144 +1013,19 @@ checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]]
name = "hf-hub"
version = "0.4.3"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97"
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
dependencies = [
"dirs",
"futures",
"http",
"indicatif",
"libc",
"log",
"native-tls",
"num_cpus",
"rand",
"reqwest",
"rand 0.8.5",
"serde",
"serde_json",
"thiserror 2.0.12",
"tokio",
"thiserror 1.0.69",
"ureq",
"windows-sys 0.60.2",
]
[[package]]
name = "http"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "http-body"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
"http",
]
[[package]]
name = "http-body-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"pin-project-lite",
]
[[package]]
name = "httparse"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "hyper"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2",
"http",
"http-body",
"httparse",
"itoa",
"pin-project-lite",
"smallvec",
"tokio",
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
dependencies = [
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb"
dependencies = [
"base64",
"bytes",
"futures-channel",
"futures-core",
"futures-util",
"http",
"http-body",
"hyper",
"ipnet",
"libc",
"percent-encoding",
"pin-project-lite",
"socket2",
"system-configuration",
"tokio",
"tower-service",
"tracing",
"windows-registry",
]
[[package]]
@ -1354,22 +1158,6 @@ dependencies = [
"web-time",
]
[[package]]
name = "ipnet"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "iri-string"
version = "0.7.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2"
dependencies = [
"memchr",
"serde",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
@ -1557,12 +1345,6 @@ dependencies = [
"paste",
]
[[package]]
name = "mime"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "miniz_oxide"
version = "0.8.9"
@ -1777,7 +1559,7 @@ dependencies = [
"futures-io",
"pin-project",
"tokio",
"tokio-util 0.6.10",
"tokio-util",
]
[[package]]
@ -2040,14 +1822,35 @@ version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha 0.3.1",
"rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha",
"rand_core",
"rand_chacha 0.9.0",
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core 0.6.4",
]
[[package]]
@ -2057,7 +1860,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core",
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom 0.2.16",
]
[[package]]
@ -2076,7 +1888,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand",
"rand 0.9.1",
]
[[package]]
@ -2143,13 +1955,13 @@ dependencies = [
[[package]]
name = "redox_users"
version = "0.5.0"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b"
checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
dependencies = [
"getrandom 0.2.16",
"libredox",
"thiserror 2.0.12",
"thiserror 1.0.69",
]
[[package]]
@ -2181,49 +1993,6 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
version = "0.12.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eabf4c97d9130e2bf606614eb937e86edac8292eaa6f422f995d7e8de1eb1813"
dependencies = [
"base64",
"bytes",
"encoding_rs",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-tls",
"hyper-util",
"js-sys",
"log",
"mime",
"native-tls",
"percent-encoding",
"pin-project-lite",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tokio-util 0.7.15",
"tower",
"tower-http",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
]
[[package]]
name = "ring"
version = "0.17.14"
@ -2318,12 +2087,6 @@ dependencies = [
"untrusted",
]
[[package]]
name = "rustversion"
version = "1.0.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d"
[[package]]
name = "ryu"
version = "1.0.20"
@ -2460,18 +2223,6 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd"
dependencies = [
"form_urlencoded",
"itoa",
"ryu",
"serde",
]
[[package]]
name = "shlex"
version = "1.3.0"
@ -2509,17 +2260,6 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "socks"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b"
dependencies = [
"byteorder",
"libc",
"winapi",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.0"
@ -2761,15 +2501,6 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
version = "0.13.2"
@ -2809,27 +2540,6 @@ dependencies = [
"walkdir",
]
[[package]]
name = "system-configuration"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
dependencies = [
"bitflags 2.9.1",
"core-foundation",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "tempfile"
version = "3.20.0"
@ -2922,26 +2632,6 @@ dependencies = [
"syn 2.0.103",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b"
dependencies = [
"rustls",
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.6.10"
@ -2957,19 +2647,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.7.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]]
name = "toml_datetime"
version = "0.6.11"
@ -2987,51 +2664,6 @@ dependencies = [
"winnow",
]
[[package]]
name = "tower"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9"
dependencies = [
"futures-core",
"futures-util",
"pin-project-lite",
"sync_wrapper",
"tokio",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-http"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2"
dependencies = [
"bitflags 2.9.1",
"bytes",
"futures-util",
"http",
"http-body",
"iri-string",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
[[package]]
name = "tower-service"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tracing"
version = "0.1.41"
@ -3073,12 +2705,6 @@ dependencies = [
"strength_reduce",
]
[[package]]
name = "try-lock"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "ug"
version = "0.4.0"
@ -3160,7 +2786,6 @@ dependencies = [
"rustls-pki-types",
"serde",
"serde_json",
"socks",
"url",
"webpki-roots 0.26.11",
]
@ -3210,15 +2835,6 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e"
dependencies = [
"try-lock",
]
[[package]]
name = "wasi"
version = "0.11.1+wasi-snapshot-preview1"
@ -3242,7 +2858,6 @@ checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5"
dependencies = [
"cfg-if",
"once_cell",
"rustversion",
"wasm-bindgen-macro",
]
@ -3260,19 +2875,6 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
dependencies = [
"cfg-if",
"js-sys",
"once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.100"
@ -3305,29 +2907,6 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
@ -3356,22 +2935,6 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.9"
@ -3382,44 +2945,12 @@ dependencies = [
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
name = "windows-sys"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-link"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a"
[[package]]
name = "windows-registry"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b8a9ed28765efc97bbc954883f4e6796c33a06546ebafacbabee9696967499e"
checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
dependencies = [
"windows-link",
"windows-result",
"windows-strings",
]
[[package]]
name = "windows-result"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6"
dependencies = [
"windows-link",
]
[[package]]
name = "windows-strings"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57"
dependencies = [
"windows-link",
"windows-targets 0.48.5",
]
[[package]]
@ -3441,12 +2972,18 @@ dependencies = [
]
[[package]]
name = "windows-sys"
version = "0.60.2"
name = "windows-targets"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb"
checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
dependencies = [
"windows-targets 0.53.2",
"windows_aarch64_gnullvm 0.48.5",
"windows_aarch64_msvc 0.48.5",
"windows_i686_gnu 0.48.5",
"windows_i686_msvc 0.48.5",
"windows_x86_64_gnu 0.48.5",
"windows_x86_64_gnullvm 0.48.5",
"windows_x86_64_msvc 0.48.5",
]
[[package]]
@ -3481,6 +3018,12 @@ dependencies = [
"windows_x86_64_msvc 0.53.0",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
@ -3493,6 +3036,12 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
@ -3505,6 +3054,12 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
@ -3529,6 +3084,12 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
@ -3541,6 +3102,12 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
@ -3553,6 +3120,12 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
@ -3565,6 +3138,12 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"

View File

@ -8,7 +8,7 @@ anyhow = "1.0"
candle = { version = "0.9.1", package = "candle-core" }
candle-nn = "0.9.1"
clap = { version = "4.4.12", features = ["derive"] }
hf-hub = "0.4.3"
hf-hub = "0.3.2"
kaudio = "0.2.1"
moshi = "0.6.1"
sentencepiece = "0.11.3"

View File

@ -18,14 +18,6 @@ struct Args {
/// Run the model on cpu.
#[arg(long)]
cpu: bool,
/// Display word level timestamps.
#[arg(long)]
timestamps: bool,
/// Display the level of voice activity detection (VAD).
#[arg(long)]
vad: bool,
}
fn device(cpu: bool) -> Result<Device> {
@ -40,12 +32,6 @@ fn device(cpu: bool) -> Result<Device> {
}
}
#[derive(Debug, serde::Deserialize)]
struct SttConfig {
audio_silence_prefix_seconds: f64,
audio_delay_seconds: f64,
}
#[derive(Debug, serde::Deserialize)]
struct Config {
mimi_name: String,
@ -59,11 +45,10 @@ struct Config {
num_heads: usize,
num_layers: usize,
causal: bool,
stt_config: SttConfig,
}
impl Config {
fn model_config(&self, vad: bool) -> moshi::lm::Config {
fn model_config(&self) -> moshi::lm::Config {
let lm_cfg = moshi::transformer::Config {
d_model: self.dim,
num_heads: self.num_heads,
@ -88,14 +73,6 @@ impl Config {
max_seq_len: 4096 * 4,
shared_cross_attn: false,
};
let extra_heads = if vad {
Some(moshi::lm::ExtraHeadsConfig {
num_heads: 4,
dim: 6,
})
} else {
None
};
moshi::lm::Config {
transformer: lm_cfg,
depformer: None,
@ -104,7 +81,7 @@ impl Config {
text_out_vocab_size: self.text_card,
audio_codebooks: self.n_q,
conditioners: Default::default(),
extra_heads,
extra_heads: None,
}
}
}
@ -112,19 +89,16 @@ impl Config {
struct Model {
state: moshi::asr::State,
text_tokenizer: sentencepiece::SentencePieceProcessor,
timestamps: bool,
vad: bool,
config: Config,
dev: Device,
}
impl Model {
fn load_from_hf(args: &Args, dev: &Device) -> Result<Self> {
fn load_from_hf(hf_repo: &str, dev: &Device) -> Result<Self> {
let dtype = dev.bf16_default_to_f32();
// Retrieve the model files from the Hugging Face Hub
let api = hf_hub::api::sync::Api::new()?;
let repo = api.model(args.hf_repo.to_string());
let repo = api.model(hf_repo.to_string());
let config_file = repo.get("config.json")?;
let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?;
let tokenizer_file = repo.get(&config.tokenizer_name)?;
@ -136,86 +110,53 @@ impl Model {
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? };
let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?;
let lm = moshi::lm::LmModel::new(
&config.model_config(args.vad),
&config.model_config(),
moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
)?;
let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
let state = moshi::asr::State::new(1, 0, 0., audio_tokenizer, lm)?;
Ok(Model {
state,
config,
text_tokenizer,
timestamps: args.timestamps,
vad: args.vad,
dev: dev.clone(),
})
}
fn run(&mut self, mut pcm: Vec<f32>) -> Result<()> {
fn run(&mut self, pcm: &[f32]) -> Result<()> {
use std::io::Write;
// Add the silence prefix to the audio.
if self.config.stt_config.audio_silence_prefix_seconds > 0.0 {
let silence_len =
(self.config.stt_config.audio_silence_prefix_seconds * 24000.0) as usize;
pcm.splice(0..0, vec![0.0; silence_len]);
}
// Add some silence at the end to ensure all the audio is processed.
let suffix = (self.config.stt_config.audio_delay_seconds * 24000.0) as usize;
pcm.resize(pcm.len() + suffix + 24000, 0.0);
let mut last_word = None;
let mut printed_eot = false;
for pcm in pcm.chunks(1920) {
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
let mut prev_text_token = 0;
for asr_msg in asr_msgs.iter() {
match asr_msg {
moshi::asr::AsrMsg::Step { prs, .. } => {
// prs is the probability of having no voice activity for different time
// horizons.
// In kyutai/stt-1b-en_fr-candle, these horizons are 0.5s, 1s, 2s, and 3s.
if self.vad && prs[2][0] > 0.5 && !printed_eot {
printed_eot = true;
if !self.timestamps {
print!(" <endofturn pr={}>", prs[2][0]);
} else {
println!("<endofturn pr={}>", prs[2][0]);
}
}
}
moshi::asr::AsrMsg::EndWord { stop_time, .. } => {
printed_eot = false;
if self.timestamps {
if let Some((word, start_time)) = last_word.take() {
println!("[{start_time:5.2}-{stop_time:5.2}] {word}");
}
}
}
moshi::asr::AsrMsg::Word {
tokens, start_time, ..
} => {
printed_eot = false;
let word = self
.text_tokenizer
.decode_piece_ids(tokens)
.unwrap_or_else(|_| String::new());
if !self.timestamps {
print!(" {word}");
moshi::asr::AsrMsg::Step { .. } | moshi::asr::AsrMsg::EndWord { .. } => {}
moshi::asr::AsrMsg::Word { tokens, .. } => {
for &text_token in tokens.iter() {
let s = {
let prev_ids =
self.text_tokenizer.decode_piece_ids(&[prev_text_token]);
let ids = self
.text_tokenizer
.decode_piece_ids(&[prev_text_token, text_token]);
prev_text_token = text_token;
prev_ids.and_then(|prev_ids| {
ids.map(|ids| {
if ids.len() > prev_ids.len() {
ids[prev_ids.len()..].to_string()
} else {
String::new()
}
})
})?
};
print!("{s}");
std::io::stdout().flush()?
} else {
if let Some((word, prev_start_time)) = last_word.take() {
println!("[{prev_start_time:5.2}-{start_time:5.2}] {word}");
}
last_word = Some((word, *start_time));
}
}
}
}
}
if let Some((word, start_time)) = last_word.take() {
println!("[{start_time:5.2}- ] {word}");
}
println!();
Ok(())
}
@ -227,15 +168,17 @@ fn main() -> Result<()> {
println!("Using device: {:?}", device);
println!("Loading audio file from: {}", args.in_file);
let (pcm, sample_rate) = kaudio::pcm_decode(&args.in_file)?;
let pcm = if sample_rate != 24_000 {
let (pcm, sample_rate) = kaudio::pcm_decode(args.in_file)?;
let mut pcm = if sample_rate != 24_000 {
kaudio::resample(&pcm, sample_rate as usize, 24_000)?
} else {
pcm
};
// Add some silence at the end to ensure all the audio is processed.
pcm.resize(pcm.len() + 1920 * 32, 0.0);
println!("Loading model from repository: {}", args.hf_repo);
let mut model = Model::load_from_hf(&args, &device)?;
let mut model = Model::load_from_hf(&args.hf_repo, &device)?;
println!("Running inference");
model.run(pcm)?;
model.run(&pcm)?;
Ok(())
}
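
A quick numeric check of the delay handling in the diff above (written in Python for brevity): the model steps at 12.5 Hz over 24 kHz audio, which is where the 1920-sample chunks come from, so both the token delay and the trailing silence follow directly from audio_delay_seconds. The value below is hypothetical; in practice it is read from the checkpoint's config.json.

audio_delay_seconds = 2.5                              # hypothetical value, comes from config.json in practice
asr_delay_in_tokens = int(audio_delay_seconds * 12.5)  # 31 tokens of lag before text can appear
suffix_samples = int(audio_delay_seconds * 24000)      # 60000 samples of trailing silence to flush the tail
samples_per_step = int(24000 / 12.5)                   # 1920 samples per step, i.e. 80 ms of audio
print(asr_delay_in_tokens, suffix_samples, samples_per_step)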

View File

@ -1,238 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gJEMjPgeI-rw",
"outputId": "7491c067-b1be-4505-b3f5-19ba4c00a593"
},
"outputs": [],
"source": [
"!pip install moshi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "CA4K5iDFJcqJ",
"outputId": "b609843a-a193-4729-b099-5f8780532333"
},
"outputs": [],
"source": [
"!wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VA3Haix3IZ8Q"
},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"import time\n",
"import sentencepiece\n",
"import sphn\n",
"import textwrap\n",
"import torch\n",
"\n",
"from moshi.models import loaders, MimiModel, LMModel, LMGen"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9AK5zBMTI9bw"
},
"outputs": [],
"source": [
"@dataclass\n",
"class InferenceState:\n",
" mimi: MimiModel\n",
" text_tokenizer: sentencepiece.SentencePieceProcessor\n",
" lm_gen: LMGen\n",
"\n",
" def __init__(\n",
" self,\n",
" mimi: MimiModel,\n",
" text_tokenizer: sentencepiece.SentencePieceProcessor,\n",
" lm: LMModel,\n",
" batch_size: int,\n",
" device: str | torch.device,\n",
" ):\n",
" self.mimi = mimi\n",
" self.text_tokenizer = text_tokenizer\n",
" self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)\n",
" self.device = device\n",
" self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)\n",
" self.batch_size = batch_size\n",
" self.mimi.streaming_forever(batch_size)\n",
" self.lm_gen.streaming_forever(batch_size)\n",
"\n",
" def run(self, in_pcms: torch.Tensor):\n",
" device = self.lm_gen.lm_model.device\n",
" ntokens = 0\n",
" first_frame = True\n",
" chunks = [\n",
" c\n",
" for c in in_pcms.split(self.frame_size, dim=2)\n",
" if c.shape[-1] == self.frame_size\n",
" ]\n",
" start_time = time.time()\n",
" all_text = []\n",
" for chunk in chunks:\n",
" codes = self.mimi.encode(chunk)\n",
" if first_frame:\n",
" # Ensure that the first slice of codes is properly seen by the transformer\n",
" # as otherwise the first slice is replaced by the initial tokens.\n",
" tokens = self.lm_gen.step(codes)\n",
" first_frame = False\n",
" tokens = self.lm_gen.step(codes)\n",
" if tokens is None:\n",
" continue\n",
" assert tokens.shape[1] == 1\n",
" one_text = tokens[0, 0].cpu()\n",
" if one_text.item() not in [0, 3]:\n",
" text = self.text_tokenizer.id_to_piece(one_text.item())\n",
" text = text.replace(\"▁\", \" \")\n",
" all_text.append(text)\n",
" ntokens += 1\n",
" dt = time.time() - start_time\n",
" print(\n",
" f\"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step\"\n",
" )\n",
" return \"\".join(all_text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 353,
"referenced_widgets": [
"0a5f6f887e2b4cd1990a0e9ec0153ed9",
"f7893826fcba4bdc87539589d669249b",
"8805afb12c484781be85082ff02dad13",
"97679c0d9ab44bed9a3456f2fcb541fd",
"d73c0321bed54a52b5e1da0a7788e32a",
"d67be13a920d4fc89e5570b5b29fc1d2",
"6b377c2d7bf945fb89e46c39d246a332",
"b82ff365c78e41ad8094b46daf79449d",
"477aa7fa82dc42d5bce6f1743c45d626",
"cbd288510c474430beb66f346f382c45",
"aafc347cdf28428ea6a7abe5b46b726f",
"fca09acd5d0d45468c8b04bfb2de7646",
"79e35214b51b4a9e9b3f7144b0b34f7b",
"89e9a37f69904bd48b954d627bff6687",
"57028789c78248a7b0ad4f031c9545c9",
"1150fcb427994c2984d4d0f4e4745fe5",
"e24b1fc52f294f849019c9b3befb613f",
"8724878682cf4c3ca992667c45009398",
"36a22c977d5242008871310133b7d2af",
"5b3683cad5cb4877b43fadd003edf97f",
"703f98272e4d469d8f27f5a465715dd8",
"9dbe02ef5fac41cfaee3d02946e65c88",
"37faa87ad03a4271992c21ce6a629e18",
"570c547e48cd421b814b2c5e028e4c0b",
"b173768580fc4c0a8e3abf272e4c363a",
"e57d1620f0a9427b85d8b4885ef4e8e3",
"5dd4474df70743498b616608182714dd",
"cc907676a65f4ad1bf68a77b4a00e89b",
"a34abc3b118e4305951a466919c28ff6",
"a77ccfcdb90146c7a63b4b2d232bc494",
"f7313e6e3a27475993cab3961d6ae363",
"39b47fad9c554839868fe9e4bbf7def2",
"14e9511ea0bd44c49f0cf3abf1a6d40e",
"a4ea8e0c4cac4d5e88b7e3f527e4fe90",
"571afc0f4b2840c9830d6b5a307ed1f9",
"6ec593cab5b64f0ea638bb175b9daa5c",
"77a52aed00ae408bb24524880e19ec8a",
"0b2de4b29b4b44fe9d96361a40c793d0",
"3c5b5fb1a5ac468a89c1058bd90cfb58",
"e53e0a2a240e43cfa562c89b3d703dea",
"35966343cf9249ef8bc028a0d5c5f97d",
"e36a37e0d41c47ccb8bc6d56c19fb17c",
"279ccf7de43847a1a6579c9182a46cc8",
"41b5d6ab0b7d43c790a55f125c0e7494"
]
},
"id": "UsQJdAgkLp9n",
"outputId": "9b7131c3-69c5-4323-8312-2ce7621d8869"
},
"outputs": [],
"source": [
"device = \"cuda\"\n",
"# Use the en+fr low latency model, an alternative is kyutai/stt-2.6b-en\n",
"checkpoint_info = loaders.CheckpointInfo.from_hf_repo(\"kyutai/stt-1b-en_fr\")\n",
"mimi = checkpoint_info.get_mimi(device=device)\n",
"text_tokenizer = checkpoint_info.get_text_tokenizer()\n",
"lm = checkpoint_info.get_moshi(device=device)\n",
"in_pcms, _ = sphn.read(\"sample_fr_hibiki_crepes.mp3\", sample_rate=mimi.sample_rate)\n",
"in_pcms = torch.from_numpy(in_pcms).to(device=device)\n",
"\n",
"stt_config = checkpoint_info.stt_config\n",
"pad_left = int(stt_config.get(\"audio_silence_prefix_seconds\", 0.0) * 24000)\n",
"pad_right = int((stt_config.get(\"audio_delay_seconds\", 0.0) + 1.0) * 24000)\n",
"in_pcms = torch.nn.functional.pad(in_pcms, (pad_left, pad_right), mode=\"constant\")\n",
"in_pcms = in_pcms[None, 0:1].expand(1, -1, -1)\n",
"\n",
"state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)\n",
"text = state.run(in_pcms)\n",
"print(textwrap.fill(text, width=100))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 75
},
"id": "CIAXs9oaPrtj",
"outputId": "94cc208c-2454-4dd4-a64e-d79025144af5"
},
"outputs": [],
"source": [
"from IPython.display import Audio\n",
"\n",
"Audio(\"sample_fr_hibiki_crepes.mp3\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qkUZ6CBKOdTa"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "L4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
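
Each call to lm_gen.step in the loop above consumes one frame of frame_size = sample_rate / frame_rate samples, which for these models is 1920 samples or 80 ms of audio at 24 kHz, so the timing line the notebook prints maps directly to a real-time factor. An illustrative calculation (the step count and wall-clock time below are made up):

samples_per_step = int(24000 / 12.5)                     # 1920 samples, 80 ms of audio per step
ntokens, dt = 1250, 20.0                                 # hypothetical: 1250 steps took 20 s of wall-clock time
audio_seconds = ntokens * samples_per_step / 24000
print(f"real-time factor: {audio_seconds / dt:.1f}x")    # 100 s of audio in 20 s -> 5.0x real time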

View File

@ -1,140 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0",
"metadata": {},
"outputs": [],
"source": [
"# Fast install, might break in the future.\n",
"!pip install 'sphn<0.2'\n",
"!pip install --no-deps \"moshi==0.2.11\"\n",
"# Slow install (will download torch and cuda), but future proof.\n",
"# !pip install \"moshi==0.2.11\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import argparse\n",
"import sys\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from moshi.models.loaders import CheckpointInfo\n",
"from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel\n",
"\n",
"from IPython.display import display, Audio"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
"# Configuration\n",
"text = \"Hey there! How are you? I had the craziest day today.\"\n",
"voice = \"expresso/ex03-ex01_happy_001_channel1_334s.wav\"\n",
"print(f\"See https://huggingface.co/{DEFAULT_DSM_TTS_VOICE_REPO} for available voices.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"# Set everything up\n",
"checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n",
"tts_model = TTSModel.from_checkpoint_info(\n",
" checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\")\n",
")\n",
"\n",
"# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n",
"entries = tts_model.prepare_script([text], padding_between=1)\n",
"voice_path = tts_model.get_voice_path(voice)\n",
"# CFG coef goes here because the model was trained with CFG distillation,\n",
"# so it's not _actually_ doing CFG at inference time.\n",
"# Also, if you are generating a dialog, you should have two voices in the list.\n",
"condition_attributes = tts_model.make_condition_attributes(\n",
" [voice_path], cfg_coef=2.0\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"print(\"Generating audio...\")\n",
"\n",
"pcms = []\n",
"def _on_frame(frame):\n",
" print(\"Step\", len(pcms), end=\"\\r\")\n",
" if (frame != -1).all():\n",
" pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n",
" pcms.append(np.clip(pcm[0, 0], -1, 1))\n",
"\n",
"# You could also generate multiple audios at once by extending the following lists.\n",
"all_entries = [entries]\n",
"all_condition_attributes = [condition_attributes]\n",
"with tts_model.mimi.streaming(len(all_entries)):\n",
" result = tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)\n",
"\n",
"print(\"Done generating.\")\n",
"audio = np.concatenate(pcms, axis=-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"display(\n",
" Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
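
The comments in the notebook above mention the dialog case: pass one entry per speaker turn to prepare_script and one voice per speaker to make_condition_attributes. A minimal sketch of that, reusing the tts_model set up in the earlier cells; the second voice path is a placeholder, pick a real file from the voice repository listing.

turns = [
    "Hey there! How are you?",       # speaker 1
    "Pretty good, today was wild.",  # speaker 2
    "Tell me everything.",           # speaker 1 again
]
entries = tts_model.prepare_script(turns, padding_between=1)
voice_paths = [
    tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav"),
    tts_model.get_voice_path("expresso/<second_voice>.wav"),  # placeholder: use a real voice file
]
condition_attributes = tts_model.make_condition_attributes(voice_paths, cfg_coef=2.0)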