Compare commits
73 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 96042c9983 | |||
| 7e92c6f28a | |||
| 8e51c6eab9 | |||
| 34e58b8055 | |||
| f23d1be027 | |||
|
|
cf97f8d863 | ||
|
|
09468c239a | ||
|
|
07729ed47e | ||
|
|
af2283de3f | ||
|
|
7dc926d50c | ||
|
|
ab8e8c59b7 | ||
|
|
5f17114618 | ||
|
|
405a82ba3f | ||
|
|
3b584b100c | ||
|
|
a98eb94ade | ||
|
|
a2f031deb5 | ||
|
|
66a33c989f | ||
|
|
89a2ced839 | ||
|
|
baf0c75bba | ||
|
|
952319de90 | ||
|
|
6d3bb6b1f1 | ||
|
|
12dbe36b0b | ||
|
|
cafac63222 | ||
|
|
7336d7a3da | ||
|
|
70500c620e | ||
|
|
f8e97aa4f3 | ||
|
|
91a4d120cb | ||
|
|
bfc200f6ee | ||
|
|
f9739881e6 | ||
|
|
99599fa408 | ||
|
|
3a4165a84f | ||
|
|
e9bac066ea | ||
|
|
eae5e17975 | ||
|
|
c1d248abba | ||
|
|
c6f262346f | ||
|
|
3573ee90af | ||
|
|
25574aa104 | ||
|
|
1cd9529f65 | ||
|
|
0ee2354176 | ||
|
|
dc8bffabe0 | ||
|
|
5f8e924176 | ||
|
|
d3bed09f9a | ||
|
|
ef52b8ef0f | ||
|
|
d92e4c2695 | ||
|
|
6c1e9f12cf | ||
|
|
236df522b8 | ||
|
|
20cf8d7365 | ||
|
|
ae575a04c6 | ||
|
|
433dca3751 | ||
|
|
07ac744609 | ||
|
|
96ff217437 | ||
|
|
7294fbcc3a | ||
|
|
c4ef93770a | ||
|
|
395eaeae95 | ||
|
|
4985940aad | ||
|
|
0112245ef7 | ||
|
|
96eef33c4c | ||
|
|
7b5a01dfba | ||
|
|
8bd3f59631 | ||
|
|
473b179cc8 | ||
|
|
2198f1d660 | ||
|
|
a3ed93d16b | ||
|
|
ef864a6f38 | ||
|
|
dd5cbcbeef | ||
|
|
d7642ff1e9 | ||
|
|
d473deddaf | ||
|
|
3282de0559 | ||
|
|
5c37f42ff2 | ||
|
|
142a02f6da | ||
|
|
5549b61d1c | ||
|
|
35c4ea47d8 | ||
|
|
91fb68acc4 | ||
|
|
957edae092 |
83
.github/ISSUE_TEMPLATE/bug.yml
vendored
Normal file
83
.github/ISSUE_TEMPLATE/bug.yml
vendored
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
name: Bug Report
|
||||||
|
description: You found a bug.
|
||||||
|
labels: ["bug", "triage"]
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Please first check the [FAQ](https://github.com/kyutai-labs/delayed-streams-modeling/blob/main/FAQ.md).
|
||||||
|
- type: dropdown
|
||||||
|
id: backend
|
||||||
|
attributes:
|
||||||
|
label: Backend impacted
|
||||||
|
description: Which backend is concerned with your bug report?
|
||||||
|
options:
|
||||||
|
- The PyTorch implementation
|
||||||
|
- The MLX implementation
|
||||||
|
- The Rust implementation
|
||||||
|
- Other / All
|
||||||
|
default: 0
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
- type: dropdown
|
||||||
|
id: os
|
||||||
|
attributes:
|
||||||
|
label: Operating system
|
||||||
|
description: What is your operating system?
|
||||||
|
options:
|
||||||
|
- Linux
|
||||||
|
- Mac OS X
|
||||||
|
- Windows (unsupported)
|
||||||
|
default: 0
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
- type: dropdown
|
||||||
|
id: hardware
|
||||||
|
attributes:
|
||||||
|
label: Hardware
|
||||||
|
description: What hardware are you using?
|
||||||
|
options:
|
||||||
|
- CPU
|
||||||
|
- GPU with CUDA
|
||||||
|
- Metal with MLX
|
||||||
|
default: 0
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
attributes:
|
||||||
|
label: Description
|
||||||
|
description: Provide a detailed description of your bug.
|
||||||
|
placeholder:
|
||||||
|
value:
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
- type: textarea
|
||||||
|
id: more_info
|
||||||
|
attributes:
|
||||||
|
label: Extra information
|
||||||
|
description: Please provide any other relevant information, such as log extracts, code etc.
|
||||||
|
placeholder:
|
||||||
|
value:
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
- type: textarea
|
||||||
|
id: env
|
||||||
|
attributes:
|
||||||
|
label: Environment
|
||||||
|
description: Please provide any other relevant information, such as log extracts, code etc.
|
||||||
|
placeholder:
|
||||||
|
value: |
|
||||||
|
Fill in the following information on your system.
|
||||||
|
- Operating system version:
|
||||||
|
|
||||||
|
If the backend impacted is PyTorch:
|
||||||
|
- Python version:
|
||||||
|
- PyTorch version:
|
||||||
|
- CUDA version (run `python -c 'import torch; print(torch.version.cuda)'`):
|
||||||
|
- GPU model and memory:
|
||||||
|
|
||||||
|
If the backend is MLX:
|
||||||
|
- Mac model:
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
40
.github/ISSUE_TEMPLATE/question.yml
vendored
Normal file
40
.github/ISSUE_TEMPLATE/question.yml
vendored
Normal file
|
|
@ -0,0 +1,40 @@
|
||||||
|
name: Question
|
||||||
|
description: You have a question about the codebase, the paper, or the implementation.
|
||||||
|
labels: ["question", "triage"]
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Please first check the [FAQ](https://github.com/kyutai-labs/delayed-streams-modeling/blob/main/FAQ.md).
|
||||||
|
- type: checkboxes
|
||||||
|
id: terms
|
||||||
|
attributes:
|
||||||
|
label: Due diligence
|
||||||
|
description: Have you searched the existing issues / FAQ / Google / asked ChatGPT?
|
||||||
|
options:
|
||||||
|
- label: I have done my due diligence in trying to find the answer myself.
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: dropdown
|
||||||
|
id: backend
|
||||||
|
attributes:
|
||||||
|
label: Topic
|
||||||
|
description: What is your question about?
|
||||||
|
options:
|
||||||
|
- The paper
|
||||||
|
- The PyTorch implementation
|
||||||
|
- The MLX implementation
|
||||||
|
- The Rust implementation
|
||||||
|
- Other / All
|
||||||
|
default: 0
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
- type: textarea
|
||||||
|
id: question
|
||||||
|
attributes:
|
||||||
|
label: Question
|
||||||
|
description: What is your question?
|
||||||
|
placeholder: Your question. Please make sure this is directly related to our codebase. We will not provide support for installing PyTorch, CUDA, Rust etc.
|
||||||
|
value:
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
9
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
9
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
- [ ] Read CONTRIBUTING.md, and accept the CLA by including the provided snippet. We will not accept PR without this.
|
||||||
|
- [ ] Run pre-commit hook.
|
||||||
|
- [ ] If you changed Rust code, run `cargo check`, `cargo clippy`, `cargo test`.
|
||||||
|
|
||||||
|
## PR Description
|
||||||
|
|
||||||
|
<!-- Description for the PR -->
|
||||||
28
.github/actions/moshi_build/action.yml
vendored
Executable file
28
.github/actions/moshi_build/action.yml
vendored
Executable file
|
|
@ -0,0 +1,28 @@
|
||||||
|
name: moshi_build
|
||||||
|
description: 'Build env.'
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: '3.10.14'
|
||||||
|
- uses: actions/cache@v3
|
||||||
|
id: cache
|
||||||
|
with:
|
||||||
|
path: env
|
||||||
|
key: env-${{ hashFiles('moshi/pyproject.toml') }}
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
python3 -m venv env
|
||||||
|
. env/bin/activate
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
pip install moshi==0.2.7
|
||||||
|
pip install pre-commit
|
||||||
|
- name: Setup env
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
source env/bin/activate
|
||||||
|
pre-commit install
|
||||||
17
.github/workflows/precommit.yml
vendored
Normal file
17
.github/workflows/precommit.yml
vendored
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
name: precommit
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
run_precommit:
|
||||||
|
name: Run precommit
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- uses: ./.github/actions/moshi_build
|
||||||
|
- run: |
|
||||||
|
source env/bin/activate
|
||||||
|
pre-commit run --all-files
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -192,5 +192,4 @@ cython_debug/
|
||||||
# refer to https://docs.cursor.com/context/ignore-files
|
# refer to https://docs.cursor.com/context/ignore-files
|
||||||
.cursorignore
|
.cursorignore
|
||||||
.cursorindexingignore
|
.cursorindexingignore
|
||||||
bria.mp3
|
out*.wav
|
||||||
sample_fr_hibiki_crepes.mp3
|
|
||||||
|
|
|
||||||
22
.pre-commit-config.yaml
Normal file
22
.pre-commit-config.yaml
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
repos:
|
||||||
|
# Get rid of Jupyter Notebook output because we don't want to keep it in Git
|
||||||
|
- repo: https://github.com/kynan/nbstripout
|
||||||
|
rev: 0.8.1
|
||||||
|
hooks:
|
||||||
|
- id: nbstripout
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v5.0.0
|
||||||
|
hooks:
|
||||||
|
- id: check-added-large-files
|
||||||
|
args: ["--maxkb=2048"]
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
# Ruff version.
|
||||||
|
rev: v0.11.7
|
||||||
|
hooks:
|
||||||
|
# Run the linter.
|
||||||
|
- id: ruff
|
||||||
|
types_or: [python, pyi] # Don't run on `jupyter` files
|
||||||
|
args: [--fix]
|
||||||
|
# Run the formatter.
|
||||||
|
- id: ruff-format
|
||||||
|
types_or: [python, pyi] # Don't run on `jupyter` files
|
||||||
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"python.analysis.typeCheckingMode": "standard"
|
||||||
|
}
|
||||||
58
CONTRIBUTING.md
Normal file
58
CONTRIBUTING.md
Normal file
|
|
@ -0,0 +1,58 @@
|
||||||
|
# Contributing to Delayed-Streams-Modeling
|
||||||
|
|
||||||
|
## Pull Requests
|
||||||
|
|
||||||
|
Delayed-Streams-Modeling is the implementation of a research paper.
|
||||||
|
Therefore, we do not plan on accepting many pull requests for new features.
|
||||||
|
However, we certainly welcome them for bug fixes.
|
||||||
|
|
||||||
|
1. Fork the repo and create your branch from `main`.
|
||||||
|
2. If you have changed APIs, update the documentation accordingly.
|
||||||
|
3. Ensure pre-commit hooks pass properly, in particular the linting and typing.
|
||||||
|
4. When changing the Rust code, run `cargo check`, `cargo clippy`, `cargo test`.
|
||||||
|
5. Accept the Contributor License Agreement (see after).
|
||||||
|
|
||||||
|
Note that in general, we will not accept refactoring of the code.
|
||||||
|
|
||||||
|
|
||||||
|
## Contributor License Agreement ("CLA")
|
||||||
|
|
||||||
|
In order to accept your pull request, we need you to submit a Contributor License Agreement.
|
||||||
|
|
||||||
|
If you agree with the full CLA provided in the next paragraph, copy the following statement in your PR, changing your Github Handle:
|
||||||
|
|
||||||
|
> I, {your GitHub handle}, confirm that I have read and understood the terms of the CLA of Kyutai-labs, as outlined in the repository's CONTRIBUTING.md, and I agree to be bound by these terms.
|
||||||
|
|
||||||
|
The full CLA is provided as follows:
|
||||||
|
|
||||||
|
> I, {your GitHub handle}, hereby grant to Kyutai-labs a perpetual, worldwide, non-exclusive, royalty-free,
|
||||||
|
> irrevocable license to use, modify, distribute, and sublicense my Contributions.
|
||||||
|
|
||||||
|
> I understand and accept that Contributions are limited to modifications, improvements, or changes
|
||||||
|
> to the project’s source code submitted via pull requests. I accept that Kyutai-labs has full discretion to
|
||||||
|
> review, accept, reject, or request changes to any Contributions I submit, and that submitting
|
||||||
|
> a pull request does not guarantee its inclusion in the project.
|
||||||
|
|
||||||
|
> By submitting a Contribution, I grant Kyutai-labs a perpetual, worldwide license to use, modify,
|
||||||
|
> reproduce, distribute, and create derivative works based on my Contributions.
|
||||||
|
> I also agree to assign all patent rights for any inventions or improvements that arise from my Contributions,
|
||||||
|
> giving the Kyutai-labs full rights to file for and enforce patents.
|
||||||
|
> I understand that the Kyutai-labs may commercialize, relicense, or exploit the project and my Contributions without further notice or obligation to me.
|
||||||
|
> I confirm that my Contributions are original and that I have the legal right to grant this license.
|
||||||
|
> If my Contributions include third-party materials, I will ensure that I have the necessary permissions
|
||||||
|
> and will disclose this information. I accept that once my Contributions are integrated, they may be altered or removed at the Kyutai-labs’s discretion.
|
||||||
|
|
||||||
|
> I acknowledge that I am making these Contributions voluntarily and will not receive any compensation.
|
||||||
|
> Furthermore, I understand that all Contributions, including mine, are provided on an "as-is" basis, with no warranties.
|
||||||
|
> By submitting a pull request, I agree to be bound by these terms.
|
||||||
|
|
||||||
|
## Issues
|
||||||
|
|
||||||
|
Please submit issues on our Github repository.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
By contributing to Delayed-Streams-Modeling, you agree that your contributions
|
||||||
|
will be licensed under the LICENSE-* files in the root directory of this source
|
||||||
|
tree. In particular, the rust code is licensed under APACHE, and the python code
|
||||||
|
under MIT.
|
||||||
56
FAQ.md
Normal file
56
FAQ.md
Normal file
|
|
@ -0,0 +1,56 @@
|
||||||
|
# FAQ
|
||||||
|
|
||||||
|
Here is the answer to a number of frequently asked questions.
|
||||||
|
|
||||||
|
### Torch compilation issues
|
||||||
|
|
||||||
|
With some PyTorch/triton versions, one might encounter compilation errors
|
||||||
|
like the following:
|
||||||
|
```
|
||||||
|
Traceback (most recent call last):
|
||||||
|
...
|
||||||
|
File "site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1153, in make_launcher
|
||||||
|
"launch_enter_hook": binary.__class__.launch_enter_hook,
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
torch._inductor.exc.InductorError: AttributeError: type object 'CompiledKernel' has no attribute 'launch_enter_hook'
|
||||||
|
```
|
||||||
|
|
||||||
|
If that's the case, you can disable torch compilation by setting the following
|
||||||
|
environment variable.
|
||||||
|
```bash
|
||||||
|
export NO_TORCH_COMPILE=1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Issues installing the sentencepiece dependency
|
||||||
|
|
||||||
|
On some linux distributions (arch) or on macos, the local version of cmake can
|
||||||
|
be too recent for the sentencepiece dependency.
|
||||||
|
|
||||||
|
```
|
||||||
|
CMake Error at CMakeLists.txt:15 (cmake_minimum_required):
|
||||||
|
Compatibility with CMake < 3.5 has been removed from CMake.
|
||||||
|
```
|
||||||
|
|
||||||
|
You can either downgrade your cmake version, e.g. 3.31.0 on arch works or try
|
||||||
|
setting `CMAKE_POLICY_VERSION_MINIMUM=3.5`.
|
||||||
|
|
||||||
|
If you run into some errors when compiling the sentencepiece rust bindings,
|
||||||
|
these could also be due to gcc being too recent, e.g. gcc 15. You can get
|
||||||
|
around this by using gcc-13, e.g. by setting the following after installing
|
||||||
|
the proper gcc packages.
|
||||||
|
```bash
|
||||||
|
export CMAKE_C_COMPILER=/usr/bin/gcc-13
|
||||||
|
export CMAKE_CXX_COMPILER=/usr/bin/g++-13
|
||||||
|
CC=gcc-13 CXX=g++-13 cargo build --release
|
||||||
|
```
|
||||||
|
|
||||||
|
Alternatively you can set `CXXFLAGS="-include cstdint"`, see this
|
||||||
|
[issue](https://github.com/google/sentencepiece/issues/1108).
|
||||||
|
|
||||||
|
### Will you release training code?
|
||||||
|
|
||||||
|
Some finetuning code can be found in the [kyutai-labs/moshi-finetune repo](https://github.com/kyutai-labs/moshi-finetune).
|
||||||
|
This code has not been adapted to the Speech-To-Text and Text-To-Speech models
|
||||||
|
yet, but it should be a good starting point.
|
||||||
|
|
||||||
|
|
||||||
333
README.md
333
README.md
|
|
@ -1,91 +1,123 @@
|
||||||
# delayed-streams-modeling
|
# Delayed Streams Modeling: Kyutai STT & TTS
|
||||||
Delayed Streams Modeling (DSM) is a flexible formulation for streaming, multimodal sequence-to-sequence learning.
|
|
||||||
|
|
||||||
## Speech-to-text
|
This repo contains instructions and examples of how to run
|
||||||
|
[Kyutai Speech-To-Text](#kyutai-speech-to-text)
|
||||||
|
and [Kyutai Text-To-Speech](#kyutai-text-to-speech) models.
|
||||||
|
These models are powered by delayed streams modeling (DSM),
|
||||||
|
a flexible formulation for streaming, multimodal sequence-to-sequence learning.
|
||||||
|
See also [Unmute](https://github.com/kyutai-labs/unmute), an voice AI system built using Kyutai STT and Kyutai TTS.
|
||||||
|
|
||||||
DSM can be used to build streaming speech-to-text models. We provide two such models
|
But wait, what is "Delayed Streams Modeling"? It is a technique for solving many streaming X-to-Y tasks (with X, Y in `{speech, text}`)
|
||||||
with a different delay between the audio input and the text output.
|
that formalize the approach we had with Moshi and Hibiki. A pre-print paper is coming soon!
|
||||||
- An English and French model with ~1b parameters using a 0.5 second delay,
|
|
||||||
`kyutai/stt-1b-en_fr`.
|
|
||||||
- An English only model with ~2.6b parameters using a 2.5 second delay,
|
|
||||||
`kyutai/stt-2.6b-en`.
|
|
||||||
|
|
||||||
These speech-to-text models have several advantages:
|
## Kyutai Speech-To-Text
|
||||||
- Easy batching for maximum efficiency: a H100 can process 400 streams in
|
|
||||||
real-time.
|
|
||||||
- Streaming inference: the models can process audio in chunks, which allows
|
|
||||||
for real-time transcription, and is great for interactive applications.
|
|
||||||
- Return word-level timestamps.
|
|
||||||
- Some models have a semantic Voice Activity Detection (VAD) component that
|
|
||||||
can be used to detect when the user is speaking. This is especially useful
|
|
||||||
for building voice agents.
|
|
||||||
|
|
||||||
More details can be found on the [project page](https://kyutai.org/next/stt).
|
<a href="https://huggingface.co/collections/kyutai/speech-to-text-685403682cf8a23ab9466886" target="_blank" style="margin: 2px;">
|
||||||
|
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiSTT-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||||
You can retrieve the sample files used in the following snippets via:
|
|
||||||
```bash
|
|
||||||
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
|
|
||||||
wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3
|
|
||||||
```
|
|
||||||
|
|
||||||
### PyTorch implementation
|
|
||||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;">
|
|
||||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
|
||||||
</a>
|
</a>
|
||||||
<a target="_blank" href="https://colab.research.google.com/drive/1mc0Q-FoHxU2pEvId8rTdS4q1r1zorJhS?usp=sharing">
|
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb">
|
||||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
|
**More details can be found on the [project page](https://kyutai.org/next/stt).**
|
||||||
|
|
||||||
|
Kyutai STT models are optimized for real-time usage, can be batched for efficiency, and return word level timestamps.
|
||||||
|
We provide two models:
|
||||||
|
- `kyutai/stt-1b-en_fr`, an English and French model with ~1B parameters, a 0.5 second delay, and a [semantic VAD](https://kyutai.org/next/stt#semantic-vad).
|
||||||
|
- `kyutai/stt-2.6b-en`, an English-only model with ~2.6B parameters and a 2.5 second delay.
|
||||||
|
|
||||||
|
These speech-to-text models have several advantages:
|
||||||
|
- Streaming inference: the models can process audio in chunks, which allows
|
||||||
|
for real-time transcription, and is great for interactive applications.
|
||||||
|
- Easy batching for maximum efficiency: a H100 can process 400 streams in
|
||||||
|
real-time.
|
||||||
|
- They return word-level timestamps.
|
||||||
|
- The 1B model has a semantic Voice Activity Detection (VAD) component that
|
||||||
|
can be used to detect when the user is speaking. This is especially useful
|
||||||
|
for building voice agents.
|
||||||
|
|
||||||
|
### Implementations overview
|
||||||
|
|
||||||
|
We provide different implementations of Kyutai STT for different use cases.
|
||||||
|
Here is how to choose which one to use:
|
||||||
|
|
||||||
|
- **PyTorch: for research and tinkering.**
|
||||||
|
If you want to call the model from Python for research or experimentation, use our PyTorch implementation.
|
||||||
|
- **Rust: for production.**
|
||||||
|
If you want to serve Kyutai STT in a production setting, use our Rust server.
|
||||||
|
Our robust Rust server provides streaming access to the model over websockets.
|
||||||
|
We use this server to run [Unmute](https://unmute.sh/); on a L40S GPU, we can serve 64 simultaneous connections at a real-time factor of 3x.
|
||||||
|
- **MLX: for on-device inference on iPhone and Mac.**
|
||||||
|
MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
|
||||||
|
If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>PyTorch implementation</summary>
|
||||||
|
<a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;">
|
||||||
|
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||||
|
</a>
|
||||||
|
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb">
|
||||||
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||||
|
</a>
|
||||||
|
|
||||||
|
For an example of how to use the model in a way where you can directly stream in PyTorch tensors,
|
||||||
|
[see our Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb).
|
||||||
|
|
||||||
This requires the [moshi package](https://pypi.org/project/moshi/)
|
This requires the [moshi package](https://pypi.org/project/moshi/)
|
||||||
with version 0.2.5 or later, which can be installed via pip.
|
with version 0.2.6 or later, which can be installed via pip.
|
||||||
|
|
||||||
|
If you just want to run the model on a file, you can use `moshi.run_inference`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
|
python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3
|
||||||
```
|
```
|
||||||
|
|
||||||
If you have `uv` installed, you can skip the installation step and run directly:
|
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
|
||||||
```bash
|
and just prefix the command above with `uvx --with moshi`.
|
||||||
uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
|
|
||||||
```
|
|
||||||
It will install the moshi package in a temporary environment and run the speech-to-text.
|
|
||||||
|
|
||||||
### MLX implementation
|
Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs:
|
||||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">
|
|
||||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
|
||||||
</a>
|
|
||||||
|
|
||||||
This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/)
|
|
||||||
with version 0.2.5 or later, which can be installed via pip.
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
|
uv run \
|
||||||
|
scripts/stt_from_file_pytorch.py \
|
||||||
|
--hf-repo kyutai/stt-2.6b-en \
|
||||||
|
audio/bria.mp3
|
||||||
```
|
```
|
||||||
|
|
||||||
If you have `uv` installed, you can skip the installation step and run directly:
|
The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
|
||||||
```bash
|
```bash
|
||||||
uvx --with moshi-mlx python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0
|
uv run scripts/evaluate_on_dataset.py \
|
||||||
|
--dataset meanwhile \
|
||||||
|
--hf-repo kyutai/stt-2.6b-en
|
||||||
```
|
```
|
||||||
It will install the moshi package in a temporary environment and run the speech-to-text.
|
|
||||||
|
|
||||||
### Rust implementation
|
Another example shows how one can provide a text-, audio-, or text-audio prompt to our STT model:
|
||||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
|
|
||||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
|
||||||
</a>
|
|
||||||
|
|
||||||
A standalone Rust example is provided in the `stt-rs` directory in this repo.
|
|
||||||
This can be used as follows:
|
|
||||||
```bash
|
```bash
|
||||||
cd stt-rs
|
uv run scripts/stt_from_file_pytorch_with_prompt.py \
|
||||||
cargo run --features cuda -r -- bria.mp3
|
--hf-repo kyutai/stt-2.6b-en \
|
||||||
|
--file bria.mp3 \
|
||||||
|
--prompt_file ./audio/loonah.mp3 \
|
||||||
|
--prompt_text "Loonah" \
|
||||||
|
--cut-prompt-transcript
|
||||||
|
```
|
||||||
|
Produces the transcript of `bria.mp3` using the `Loonah` spelling for the name, instead of the `Luna` used without any prompt:
|
||||||
|
```
|
||||||
|
In the heart of an ancient forest, where the trees whispered secrets of the past, there lived a peculiar rabbit named Loonah (...)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Rust server
|
Apart from nudging the model for a specific spelling of a word, other potential use-cases include speaker adaptation and steering the model towards a specific formatting style or even a language.
|
||||||
|
However, please bear in mind that is an experimental feature and its behavior is very sensitive to the prompt provided.
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Rust server</summary>
|
||||||
|
|
||||||
<a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
|
<a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
|
||||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
The Rust implementation provides a server that can process multiple streaming
|
The Rust implementation provides a server that can process multiple streaming
|
||||||
queries in parallel. Dependening on the amount of memory on your GPU, you may
|
queries in parallel. Depending on the amount of memory on your GPU, you may
|
||||||
have to adjust the batch size from the config file. For a L40S GPU, a batch size
|
have to adjust the batch size from the config file. For a L40S GPU, a batch size
|
||||||
of 64 works well and requests can be processed at 3x real-time speed.
|
of 64 works well and requests can be processed at 3x real-time speed.
|
||||||
|
|
||||||
|
|
@ -100,24 +132,182 @@ cargo install --features cuda moshi-server
|
||||||
|
|
||||||
Then the server can be started via the following command using the config file
|
Then the server can be started via the following command using the config file
|
||||||
from this repository.
|
from this repository.
|
||||||
|
For `kyutai/stt-1b-en_fr`, use `configs/config-stt-en_fr.hf.toml`,
|
||||||
|
and for `kyutai/stt-2.6b-en`, use `configs/config-stt-en-hf.toml`,
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
moshi-server worker --config configs/config-stt-hf.toml
|
moshi-server worker --config configs/config-stt-en_fr-hf.toml
|
||||||
```
|
```
|
||||||
|
|
||||||
Once the server has started you can run a streaming inference with the following
|
Once the server has started you can transcribe audio from your microphone with the following script.
|
||||||
script.
|
|
||||||
```bash
|
```bash
|
||||||
uv run scripts/asr-streaming-query.py bria.mp3
|
uv run scripts/stt_from_mic_rust_server.py
|
||||||
|
```
|
||||||
|
|
||||||
|
We also provide a script for transcribing from an audio file.
|
||||||
|
```bash
|
||||||
|
uv run scripts/stt_from_file_rust_server.py audio/bria.mp3
|
||||||
```
|
```
|
||||||
|
|
||||||
The script limits the decoding speed to simulates real-time processing of the audio.
|
The script limits the decoding speed to simulates real-time processing of the audio.
|
||||||
Faster processing can be triggered by setting
|
Faster processing can be triggered by setting
|
||||||
the real-time factor, e.g. `--rtf 500` will process
|
the real-time factor, e.g. `--rtf 1000` will process
|
||||||
the data as fast as possible.
|
the data as fast as possible.
|
||||||
|
</details>
|
||||||
|
|
||||||
## Text-to-Speech
|
<details>
|
||||||
|
<summary>Rust standalone</summary>
|
||||||
|
<a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
|
||||||
|
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||||
|
</a>
|
||||||
|
|
||||||
We're in the process of open-sourcing our TTS models. Check back for updates!
|
A standalone Rust example script is provided in the `stt-rs` directory in this repo.
|
||||||
|
This can be used as follows:
|
||||||
|
```bash
|
||||||
|
cd stt-rs
|
||||||
|
cargo run --features cuda -r -- ../audio/bria.mp3
|
||||||
|
```
|
||||||
|
You can get the timestamps by adding the `--timestamps` flag, and see the output
|
||||||
|
of the semantic VAD by adding the `--vad` flag.
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>MLX implementation</summary>
|
||||||
|
<a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">
|
||||||
|
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||||
|
</a>
|
||||||
|
|
||||||
|
[MLX](https://ml-explore.github.io/mlx/build/html/index.html) is Apple's ML framework that allows you to use
|
||||||
|
hardware acceleration on Apple silicon.
|
||||||
|
|
||||||
|
This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/)
|
||||||
|
with version 0.2.6 or later, which can be installed via pip.
|
||||||
|
|
||||||
|
If you just want to run the model on a file, you can use `moshi_mlx.run_inference`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0
|
||||||
|
```
|
||||||
|
|
||||||
|
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
|
||||||
|
and just prefix the command above with `uvx --with moshi-mlx`.
|
||||||
|
|
||||||
|
If you want to transcribe audio from your microphone, use:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/stt_from_mic_mlx.py
|
||||||
|
```
|
||||||
|
|
||||||
|
The MLX models can also be used in swift using the [moshi-swift
|
||||||
|
codebase](https://github.com/kyutai-labs/moshi-swift), the 1b model has been
|
||||||
|
tested to work fine on an iPhone 16 Pro.
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Kyutai Text-to-Speech
|
||||||
|
|
||||||
|
<a href="https://huggingface.co/collections/kyutai/text-to-speech-6866192e7e004ed04fd39e29" target="_blank" style="margin: 2px;">
|
||||||
|
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiTTS-blue" style="display: inline-block; vertical-align: middle;"/>
|
||||||
|
</a>
|
||||||
|
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb">
|
||||||
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||||
|
</a>
|
||||||
|
|
||||||
|
**More details can be found on the [project page](https://kyutai.org/next/tts).**
|
||||||
|
|
||||||
|
We provide different implementations of Kyutai TTS for different use cases. Here is how to choose which one to use:
|
||||||
|
|
||||||
|
- PyTorch: for research and tinkering. If you want to call the model from Python for research or experimentation, use our PyTorch implementation.
|
||||||
|
- Rust: for production. If you want to serve Kyutai TTS in a production setting, use our Rust server. Our robust Rust server provides streaming access to the model over websockets. We use this server to run Unmute.
|
||||||
|
- MLX: for on-device inference on iPhone and Mac. MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon. If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>PyTorch implementation</summary>
|
||||||
|
|
||||||
|
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb">
|
||||||
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||||
|
</a>
|
||||||
|
|
||||||
|
Check out our [Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb) or use the script:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# From stdin, plays audio immediately
|
||||||
|
echo "Hey, how are you?" | python scripts/tts_pytorch.py - -
|
||||||
|
|
||||||
|
# From text file to audio file
|
||||||
|
python scripts/tts_pytorch.py text_to_say.txt audio_output.wav
|
||||||
|
```
|
||||||
|
|
||||||
|
The `tts_pytorch.py` script waits for all the text to be available before
|
||||||
|
starting the audio generation. A fully streaming implementation is available in
|
||||||
|
the `tts_pytorch_streaming.py` script, which can be used as follows:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
echo "Hey, how are you?" | python scripts/tts_pytorch_streaming.py audio_output.wav
|
||||||
|
```
|
||||||
|
|
||||||
|
This requires the [moshi package](https://pypi.org/project/moshi/), which can be installed via pip.
|
||||||
|
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
|
||||||
|
and just prefix the command above with `uvx --with moshi`.
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Rust server</summary>
|
||||||
|
|
||||||
|
|
||||||
|
The Rust implementation provides a server that can process multiple streaming
|
||||||
|
queries in parallel.
|
||||||
|
|
||||||
|
Installing the Rust server is a bit tricky because it uses our Python implementation under the hood,
|
||||||
|
which also requires installing the Python dependencies.
|
||||||
|
Use the [start_tts.sh](https://github.com/kyutai-labs/unmute/blob/main/dockerless/start_tts.sh) script to properly install the Rust server.
|
||||||
|
If you already installed the `moshi-server` crate before and it's not working, you might need to force a reinstall by running `cargo uninstall moshi-server` first.
|
||||||
|
Feel free to open an issue if the installation is still broken.
|
||||||
|
|
||||||
|
Once installed, the server can be started via the following command using the config file
|
||||||
|
from this repository.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
moshi-server worker --config configs/config-tts.toml
|
||||||
|
```
|
||||||
|
|
||||||
|
Once the server has started you can connect to it using our script as follows:
|
||||||
|
```bash
|
||||||
|
# From stdin, plays audio immediately
|
||||||
|
echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
|
||||||
|
|
||||||
|
# From text file to audio file
|
||||||
|
python scripts/tts_rust_server.py text_to_say.txt audio_output.wav
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>MLX implementation</summary>
|
||||||
|
|
||||||
|
[MLX](https://ml-explore.github.io/mlx/build/html/index.html) is Apple's ML framework that allows you to use
|
||||||
|
hardware acceleration on Apple silicon.
|
||||||
|
|
||||||
|
Use our example script to run Kyutai TTS on MLX.
|
||||||
|
The script takes text from stdin or a file and can output to a file or stream the resulting audio.
|
||||||
|
When streaming the output, if the model is not fast enough to keep with
|
||||||
|
real-time, you can use the `--quantize 8` or `--quantize 4` flags to quantize
|
||||||
|
the model resulting in faster inference.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# From stdin, plays audio immediately
|
||||||
|
echo "Hey, how are you?" | python scripts/tts_mlx.py - - --quantize 8
|
||||||
|
|
||||||
|
# From text file to audio file
|
||||||
|
python scripts/tts_mlx.py text_to_say.txt audio_output.wav
|
||||||
|
```
|
||||||
|
|
||||||
|
This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/), which can be installed via pip.
|
||||||
|
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
|
||||||
|
and just prefix the command above with `uvx --with moshi-mlx`.
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## FAQ
|
||||||
|
|
||||||
|
Checkout the [Frequently Asked Questions](FAQ.md) section before opening an issue.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|
@ -127,3 +317,14 @@ Note that parts of this code is based on [AudioCraft](https://github.com/faceboo
|
||||||
the MIT license.
|
the MIT license.
|
||||||
|
|
||||||
The weights for the speech-to-text models are released under the CC-BY 4.0 license.
|
The weights for the speech-to-text models are released under the CC-BY 4.0 license.
|
||||||
|
|
||||||
|
## Developing
|
||||||
|
|
||||||
|
Install the [pre-commit hooks](https://pre-commit.com/) by running:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install pre-commit
|
||||||
|
pre-commit install
|
||||||
|
```
|
||||||
|
|
||||||
|
If you're using `uv`, you can replace the two commands with `uvx pre-commit install`.
|
||||||
|
|
|
||||||
298
app/api_server.py
Normal file
298
app/api_server.py
Normal file
|
|
@ -0,0 +1,298 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
OpenAI-Compatible Kyutai TTS API Server with Model Caching
|
||||||
|
Improved version that loads the model once and keeps it in memory
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Literal
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import soundfile as sf
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.responses import Response
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Global model variables - loaded once at startup
|
||||||
|
tts_model = None
|
||||||
|
device = None
|
||||||
|
sample_rate = None
|
||||||
|
|
||||||
|
class SpeechRequest(BaseModel):
|
||||||
|
model: Literal["tts-1", "tts-1-hd"] = Field("tts-1", description="TTS model to use")
|
||||||
|
input: str = Field(..., min_length=1, max_length=4096, description="Text to generate audio for")
|
||||||
|
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field("alloy", description="Voice to use")
|
||||||
|
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field("mp3", description="Audio format")
|
||||||
|
speed: Optional[float] = Field(1.0, ge=0.25, le=4.0, description="Speed of generated audio")
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="OpenAI-Compatible TTS API (Cached)",
|
||||||
|
description="OpenAI Audio Speech API compatible endpoint using Kyutai TTS with model caching",
|
||||||
|
version="2.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
OUTPUT_DIR = Path("/app/api_output")
|
||||||
|
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
def load_tts_model():
|
||||||
|
"""Load TTS model once at startup and keep in memory"""
|
||||||
|
global tts_model, device, sample_rate
|
||||||
|
|
||||||
|
if tts_model is not None:
|
||||||
|
logger.info("TTS model already loaded")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("🚀 Loading Kyutai TTS model (one-time initialization)...")
|
||||||
|
|
||||||
|
# Import Kyutai TTS modules
|
||||||
|
from moshi.models.loaders import CheckpointInfo
|
||||||
|
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel
|
||||||
|
|
||||||
|
# Set device
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
logger.info(f"Using device: {device}")
|
||||||
|
|
||||||
|
# Load the TTS model
|
||||||
|
checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
|
||||||
|
|
||||||
|
tts_model = TTSModel.from_checkpoint_info(
|
||||||
|
checkpoint_info,
|
||||||
|
n_q=32,
|
||||||
|
temp=0.6,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get sample rate
|
||||||
|
sample_rate = tts_model.mimi.sample_rate
|
||||||
|
|
||||||
|
logger.info(f"✅ TTS model loaded successfully!")
|
||||||
|
logger.info(f" Model: {DEFAULT_DSM_TTS_REPO}")
|
||||||
|
logger.info(f" Device: {device}")
|
||||||
|
logger.info(f" Sample Rate: {sample_rate}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Failed to load TTS model: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def generate_audio_fast(text: str, voice: str = "alloy", speed: float = 1.0) -> bytes:
|
||||||
|
"""Generate audio using cached TTS model"""
|
||||||
|
global tts_model, device, sample_rate
|
||||||
|
|
||||||
|
if tts_model is None:
|
||||||
|
raise HTTPException(status_code=500, detail="TTS model not loaded")
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"🎵 Generating audio for: '{text[:50]}{'...' if len(text) > 50 else ''}'")
|
||||||
|
|
||||||
|
# Prepare the script (text input)
|
||||||
|
entries = tts_model.prepare_script([text], padding_between=1)
|
||||||
|
|
||||||
|
# Voice mapping for OpenAI compatibility
|
||||||
|
voice_mapping = {
|
||||||
|
"alloy": "expresso/ex03-ex01_happy_001_channel1_334s.wav",
|
||||||
|
"echo": "expresso/ex04-ex01_happy_001_channel1_334s.wav",
|
||||||
|
"fable": "expresso/ex05-ex01_happy_001_channel1_334s.wav",
|
||||||
|
"onyx": "expresso/ex06-ex01_happy_001_channel1_334s.wav",
|
||||||
|
"nova": "expresso/ex07-ex01_happy_001_channel1_334s.wav",
|
||||||
|
"shimmer": "expresso/ex08-ex01_happy_001_channel1_334s.wav"
|
||||||
|
}
|
||||||
|
|
||||||
|
selected_voice = voice_mapping.get(voice, voice_mapping["alloy"])
|
||||||
|
|
||||||
|
try:
|
||||||
|
voice_path = tts_model.get_voice_path(selected_voice)
|
||||||
|
except:
|
||||||
|
# Fallback to default if voice not found
|
||||||
|
voice_path = tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav")
|
||||||
|
|
||||||
|
# Prepare condition attributes
|
||||||
|
condition_attributes = tts_model.make_condition_attributes(
|
||||||
|
[voice_path], cfg_coef=2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate audio
|
||||||
|
pcms = []
|
||||||
|
|
||||||
|
def on_frame(frame):
|
||||||
|
if (frame != -1).all():
|
||||||
|
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
|
||||||
|
pcms.append(torch.clamp(torch.from_numpy(pcm[0, 0]), -1, 1).numpy())
|
||||||
|
|
||||||
|
all_entries = [entries]
|
||||||
|
all_condition_attributes = [condition_attributes]
|
||||||
|
|
||||||
|
with tts_model.mimi.streaming(len(all_entries)):
|
||||||
|
result = tts_model.generate(all_entries, all_condition_attributes, on_frame=on_frame)
|
||||||
|
|
||||||
|
# Concatenate all audio frames
|
||||||
|
if pcms:
|
||||||
|
import numpy as np
|
||||||
|
audio = np.concatenate(pcms, axis=-1)
|
||||||
|
|
||||||
|
# Apply speed adjustment if needed
|
||||||
|
if speed != 1.0:
|
||||||
|
# Simple speed adjustment by resampling
|
||||||
|
from scipy import signal
|
||||||
|
audio_length = len(audio)
|
||||||
|
new_length = int(audio_length / speed)
|
||||||
|
audio = signal.resample(audio, new_length)
|
||||||
|
|
||||||
|
# Convert to bytes
|
||||||
|
audio_bytes = io.BytesIO()
|
||||||
|
sf.write(audio_bytes, audio, samplerate=sample_rate, format='WAV')
|
||||||
|
audio_bytes.seek(0)
|
||||||
|
|
||||||
|
logger.info(f"✅ Audio generated successfully ({len(audio)/sample_rate:.2f}s)")
|
||||||
|
return audio_bytes.read()
|
||||||
|
else:
|
||||||
|
raise Exception("No audio frames generated")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ TTS generation error: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")
|
||||||
|
|
||||||
|
def convert_audio_format(audio_wav_bytes: bytes, output_format: str) -> bytes:
|
||||||
|
"""Convert WAV audio to requested format using ffmpeg"""
|
||||||
|
try:
|
||||||
|
if output_format == "wav":
|
||||||
|
return audio_wav_bytes
|
||||||
|
|
||||||
|
# Use ffmpeg to convert
|
||||||
|
cmd = ["ffmpeg", "-f", "wav", "-i", "pipe:0", "-f", output_format, "pipe:1"]
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
input=audio_wav_bytes,
|
||||||
|
capture_output=True,
|
||||||
|
check=True
|
||||||
|
)
|
||||||
|
return result.stdout
|
||||||
|
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
logger.error(f"Audio conversion failed: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Audio conversion failed: {e}")
|
||||||
|
|
||||||
|
@app.post("/v1/audio/speech")
|
||||||
|
async def create_speech(request: SpeechRequest):
|
||||||
|
"""
|
||||||
|
OpenAI-compatible audio speech endpoint
|
||||||
|
Uses cached TTS model for fast generation
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Generate audio with cached model
|
||||||
|
audio_wav_bytes = generate_audio_fast(
|
||||||
|
text=request.input,
|
||||||
|
voice=request.voice,
|
||||||
|
speed=request.speed
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to requested format
|
||||||
|
audio_data = convert_audio_format(audio_wav_bytes, request.response_format)
|
||||||
|
|
||||||
|
generation_time = time.time() - start_time
|
||||||
|
logger.info(f"⚡ Total generation time: {generation_time:.2f}s")
|
||||||
|
|
||||||
|
# Set appropriate content type
|
||||||
|
content_types = {
|
||||||
|
"mp3": "audio/mpeg",
|
||||||
|
"opus": "audio/opus",
|
||||||
|
"aac": "audio/aac",
|
||||||
|
"flac": "audio/flac",
|
||||||
|
"wav": "audio/wav",
|
||||||
|
"pcm": "audio/pcm"
|
||||||
|
}
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
content=audio_data,
|
||||||
|
media_type=content_types.get(request.response_format, "audio/wav"),
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||||
|
"X-Generation-Time": str(generation_time)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Speech generation failed: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.get("/v1/models")
|
||||||
|
async def list_models():
|
||||||
|
"""List available models (OpenAI-compatible)"""
|
||||||
|
return {
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": "tts-1",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1677610602,
|
||||||
|
"owned_by": "kyutai",
|
||||||
|
"permission": [],
|
||||||
|
"root": "tts-1",
|
||||||
|
"parent": None
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "tts-1-hd",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1677610602,
|
||||||
|
"owned_by": "kyutai",
|
||||||
|
"permission": [],
|
||||||
|
"root": "tts-1-hd",
|
||||||
|
"parent": None
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""Health check endpoint with model status"""
|
||||||
|
model_loaded = tts_model is not None
|
||||||
|
return {
|
||||||
|
"status": "healthy" if model_loaded else "loading",
|
||||||
|
"model_loaded": model_loaded,
|
||||||
|
"cuda_available": torch.cuda.is_available(),
|
||||||
|
"device": str(device) if device else None,
|
||||||
|
"service": "kyutai-tts-openai-compatible-cached"
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.get("/reload-model")
|
||||||
|
async def reload_model():
|
||||||
|
"""Reload the TTS model (admin endpoint)"""
|
||||||
|
global tts_model
|
||||||
|
try:
|
||||||
|
tts_model = None
|
||||||
|
load_tts_model()
|
||||||
|
return {"status": "success", "message": "Model reloaded successfully"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
"""Load model on startup"""
|
||||||
|
logger.info("🚀 Starting TTS API server with model caching...")
|
||||||
|
load_tts_model()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
67
app/dependency_check.py
Normal file
67
app/dependency_check.py
Normal file
|
|
@ -0,0 +1,67 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Check if all Kyutai TTS dependencies are properly installed
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def check_dependencies():
|
||||||
|
print("🔍 Checking Kyutai TTS Dependencies")
|
||||||
|
print("=" * 40)
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
"torch",
|
||||||
|
"numpy",
|
||||||
|
"einops",
|
||||||
|
"transformers",
|
||||||
|
"accelerate",
|
||||||
|
"soundfile",
|
||||||
|
"librosa",
|
||||||
|
"huggingface_hub",
|
||||||
|
"moshi",
|
||||||
|
"sphn"
|
||||||
|
]
|
||||||
|
|
||||||
|
missing = []
|
||||||
|
installed = []
|
||||||
|
|
||||||
|
for dep in dependencies:
|
||||||
|
try:
|
||||||
|
__import__(dep)
|
||||||
|
installed.append(dep)
|
||||||
|
print(f"✓ {dep}")
|
||||||
|
except ImportError as e:
|
||||||
|
missing.append((dep, str(e)))
|
||||||
|
print(f"✗ {dep}: {e}")
|
||||||
|
|
||||||
|
print(f"\n📊 Summary:")
|
||||||
|
print(f"✓ Installed: {len(installed)}")
|
||||||
|
print(f"✗ Missing: {len(missing)}")
|
||||||
|
|
||||||
|
if missing:
|
||||||
|
print(f"\n🔧 To fix missing dependencies:")
|
||||||
|
for dep, error in missing:
|
||||||
|
print(f"pip install {dep}")
|
||||||
|
|
||||||
|
print(f"\n🧪 Testing Kyutai TTS imports:")
|
||||||
|
try:
|
||||||
|
from moshi.models.loaders import CheckpointInfo
|
||||||
|
print("✓ CheckpointInfo import successful")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ CheckpointInfo import failed: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
|
||||||
|
print("✓ TTSModel imports successful")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ TTSModel imports failed: {e}")
|
||||||
|
|
||||||
|
return len(missing) == 0
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = check_dependencies()
|
||||||
|
if success:
|
||||||
|
print("\n🎉 All dependencies are installed correctly!")
|
||||||
|
else:
|
||||||
|
print("\n❌ Some dependencies are missing. Please install them first.")
|
||||||
|
sys.exit(1)
|
||||||
59
app/scripts/tts_runner.py
Normal file
59
app/scripts/tts_runner.py
Normal file
|
|
@ -0,0 +1,59 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Kyutai TTS PyTorch Runner
|
||||||
|
Dockerized implementation for text-to-speech generation
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='Kyutai TTS PyTorch Runner')
|
||||||
|
parser.add_argument('input_file', help='Input text file or "-" for stdin')
|
||||||
|
parser.add_argument('output_file', help='Output audio file')
|
||||||
|
parser.add_argument('--model', default='kyutai/tts-1.6b-en_fr', help='TTS model to use')
|
||||||
|
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"Using device: {args.device}")
|
||||||
|
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||||
|
|
||||||
|
# Handle stdin input
|
||||||
|
if args.input_file == '-':
|
||||||
|
# Read from stdin and create temporary file
|
||||||
|
text = sys.stdin.read().strip()
|
||||||
|
temp_file = '/tmp/temp_input.txt'
|
||||||
|
with open(temp_file, 'w') as f:
|
||||||
|
f.write(text)
|
||||||
|
input_file = temp_file
|
||||||
|
else:
|
||||||
|
input_file = args.input_file
|
||||||
|
|
||||||
|
# Check if the original TTS script exists
|
||||||
|
tts_script = Path('/app/scripts/tts_pytorch.py')
|
||||||
|
if tts_script.exists():
|
||||||
|
print("Using original TTS script from Kyutai repository")
|
||||||
|
import subprocess
|
||||||
|
cmd = ['python', str(tts_script), input_file, args.output_file]
|
||||||
|
subprocess.run(cmd, check=True)
|
||||||
|
else:
|
||||||
|
print("Using moshi package for TTS generation")
|
||||||
|
import subprocess
|
||||||
|
cmd = [
|
||||||
|
'python', '-m', 'moshi.run_inference',
|
||||||
|
'--hf-repo', args.model,
|
||||||
|
input_file,
|
||||||
|
args.output_file
|
||||||
|
]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
print(f"Error: {result.stderr}")
|
||||||
|
sys.exit(1)
|
||||||
|
print(f"Audio generated: {args.output_file}")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
EOF
|
||||||
BIN
audio/bria.mp3
Normal file
BIN
audio/bria.mp3
Normal file
Binary file not shown.
BIN
audio/loona.mp3
Normal file
BIN
audio/loona.mp3
Normal file
Binary file not shown.
BIN
audio/sample_fr_hibiki_crepes.mp3
Normal file
BIN
audio/sample_fr_hibiki_crepes.mp3
Normal file
Binary file not shown.
|
|
@ -1,7 +1,7 @@
|
||||||
static_dir = "./static/"
|
static_dir = "./static/"
|
||||||
log_dir = "$HOME/tmp/tts-logs"
|
log_dir = "$HOME/tmp/tts-logs"
|
||||||
instance_name = "tts"
|
instance_name = "tts"
|
||||||
authorized_ids = ["open_token"]
|
authorized_ids = ["public_token"]
|
||||||
|
|
||||||
[modules.asr]
|
[modules.asr]
|
||||||
path = "/api/asr-streaming"
|
path = "/api/asr-streaming"
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
static_dir = "./static/"
|
static_dir = "./static/"
|
||||||
log_dir = "$HOME/tmp/tts-logs"
|
log_dir = "$HOME/tmp/tts-logs"
|
||||||
instance_name = "tts"
|
instance_name = "tts"
|
||||||
authorized_ids = ["open_token"]
|
authorized_ids = ["public_token"]
|
||||||
|
|
||||||
[modules.asr]
|
[modules.asr]
|
||||||
path = "/api/asr-streaming"
|
path = "/api/asr-streaming"
|
||||||
20
configs/config-tts.toml
Normal file
20
configs/config-tts.toml
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
static_dir = "./static/"
|
||||||
|
log_dir = "$HOME/tmp/tts-logs"
|
||||||
|
instance_name = "tts"
|
||||||
|
authorized_ids = ["public_token"]
|
||||||
|
|
||||||
|
[modules.tts_py]
|
||||||
|
type = "Py"
|
||||||
|
path = "/api/tts_streaming"
|
||||||
|
text_tokenizer_file = "hf://kyutai/tts-1.6b-en_fr/tokenizer_spm_8k_en_fr_audio.model"
|
||||||
|
batch_size = 8 # Adjust to your GPU memory capacity
|
||||||
|
text_bos_token = 1
|
||||||
|
|
||||||
|
[modules.tts_py.py]
|
||||||
|
log_folder = "$HOME/tmp/moshi-server-logs"
|
||||||
|
voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
|
||||||
|
default_voice = "unmute-prod-website/default_voice.wav"
|
||||||
|
cfg_coef = 2.0
|
||||||
|
cfg_is_no_text = true
|
||||||
|
padding_between = 1
|
||||||
|
n_q = 24
|
||||||
78
install.sh
Normal file
78
install.sh
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
# Set environment variables
|
||||||
|
export DEBIAN_FRONTEND=noninteractive
|
||||||
|
export PYTHONUNBUFFERED=1
|
||||||
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
|
|
||||||
|
# Install system dependencies
|
||||||
|
apt-get update && apt-get install -y \
|
||||||
|
wget \
|
||||||
|
curl \
|
||||||
|
git \
|
||||||
|
build-essential \
|
||||||
|
libsndfile1 \
|
||||||
|
ffmpeg \
|
||||||
|
sox \
|
||||||
|
alsa-utils \
|
||||||
|
pulseaudio \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
|
||||||
|
# Install Python dependencies first (for better caching)
|
||||||
|
pip install --no-cache-dir --upgrade pip
|
||||||
|
|
||||||
|
# Create virtual environment
|
||||||
|
apt install python3.12-venv python3.12-dev
|
||||||
|
python3.12 -m venv ~/venv-tts-kyutai
|
||||||
|
source ~/venv-tts-kyutai/bin/activate
|
||||||
|
|
||||||
|
# Install Python dependencies first (for better caching)
|
||||||
|
pip install --no-cache-dir --upgrade pip
|
||||||
|
|
||||||
|
# Install PyTorch with CUDA support for Python 3.12
|
||||||
|
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
||||||
|
|
||||||
|
# Install core dependencies
|
||||||
|
pip install --no-cache-dir \
|
||||||
|
numpy \
|
||||||
|
scipy \
|
||||||
|
librosa \
|
||||||
|
soundfile \
|
||||||
|
huggingface_hub \
|
||||||
|
einops \
|
||||||
|
transformers \
|
||||||
|
accelerate
|
||||||
|
|
||||||
|
# Install API dependencies
|
||||||
|
pip install --no-cache-dir \
|
||||||
|
fastapi \
|
||||||
|
uvicorn[standard] \
|
||||||
|
python-multipart \
|
||||||
|
pydantic
|
||||||
|
|
||||||
|
# Install moshi package with all dependencies (following Colab notebook)
|
||||||
|
pip install --no-cache-dir 'sphn<0.2'
|
||||||
|
pip install --no-cache-dir "moshi==0.2.8"
|
||||||
|
|
||||||
|
# Create directories for input/output
|
||||||
|
mkdir -p /app/input /app/output /app/scripts /app/api_output
|
||||||
|
|
||||||
|
# Download the Kyutai delayed-streams-modeling repository
|
||||||
|
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo
|
||||||
|
|
||||||
|
# Copy the TTS script from the repository
|
||||||
|
cp /app/kyutai-repo/scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"
|
||||||
|
|
||||||
|
# Create directories for input/output
|
||||||
|
mkdir -p /app/input /app/output /app/scripts /app/api_output
|
||||||
|
|
||||||
|
# Download the Kyutai delayed-streams-modeling repository
|
||||||
|
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo
|
||||||
|
|
||||||
|
# Copy the TTS script from the repository
|
||||||
|
cp scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"
|
||||||
|
|
||||||
|
# Create directories for input/output
|
||||||
|
mkdir -p /app/input /app/output /app/scripts /app/api_output
|
||||||
|
|
||||||
|
# Start TTS-Server
|
||||||
|
python /app/api_server.py
|
||||||
|
|
@ -1,131 +0,0 @@
|
||||||
# /// script
|
|
||||||
# requires-python = ">=3.12"
|
|
||||||
# dependencies = [
|
|
||||||
# "msgpack",
|
|
||||||
# "numpy",
|
|
||||||
# "sphn",
|
|
||||||
# "websockets",
|
|
||||||
# ]
|
|
||||||
# ///
|
|
||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import msgpack
|
|
||||||
import sphn
|
|
||||||
import struct
|
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import websockets
|
|
||||||
|
|
||||||
# Desired audio properties
|
|
||||||
TARGET_SAMPLE_RATE = 24000
|
|
||||||
TARGET_CHANNELS = 1 # Mono
|
|
||||||
HEADERS = {"kyutai-api-key": "open_token"}
|
|
||||||
all_text = []
|
|
||||||
transcript = []
|
|
||||||
finished = False
|
|
||||||
|
|
||||||
|
|
||||||
def load_and_process_audio(file_path):
|
|
||||||
"""Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
|
|
||||||
pcm_data, _ = sphn.read(file_path, sample_rate=TARGET_SAMPLE_RATE)
|
|
||||||
return pcm_data[0]
|
|
||||||
|
|
||||||
|
|
||||||
async def receive_messages(websocket):
|
|
||||||
global all_text
|
|
||||||
global transcript
|
|
||||||
global finished
|
|
||||||
try:
|
|
||||||
async for message in websocket:
|
|
||||||
data = msgpack.unpackb(message, raw=False)
|
|
||||||
if data["type"] == "Step":
|
|
||||||
continue
|
|
||||||
print("received:", data)
|
|
||||||
if data["type"] == "Word":
|
|
||||||
all_text.append(data["text"])
|
|
||||||
transcript.append({
|
|
||||||
"speaker": "SPEAKER_00",
|
|
||||||
"text": data["text"],
|
|
||||||
"timestamp": [data["start_time"], data["start_time"]],
|
|
||||||
})
|
|
||||||
if data["type"] == "EndWord":
|
|
||||||
if len(transcript) > 0:
|
|
||||||
transcript[-1]["timestamp"][1] = data["stop_time"]
|
|
||||||
if data["type"] == "Marker":
|
|
||||||
print("Received marker, stopping stream.")
|
|
||||||
break
|
|
||||||
except websockets.ConnectionClosed:
|
|
||||||
print("Connection closed while receiving messages.")
|
|
||||||
finished = True
|
|
||||||
|
|
||||||
|
|
||||||
async def send_messages(websocket, rtf: float):
|
|
||||||
global finished
|
|
||||||
audio_data = load_and_process_audio(args.in_file)
|
|
||||||
try:
|
|
||||||
# Start with a second of silence
|
|
||||||
chunk = { "type": "Audio", "pcm": [0.0] * 24000 }
|
|
||||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
|
||||||
await websocket.send(msg)
|
|
||||||
|
|
||||||
chunk_size = 1920 # Send data in chunks
|
|
||||||
start_time = time.time()
|
|
||||||
for i in range(0, len(audio_data), chunk_size):
|
|
||||||
chunk = { "type": "Audio", "pcm": [float(x) for x in audio_data[i : i + chunk_size]] }
|
|
||||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
|
||||||
await websocket.send(msg)
|
|
||||||
expected_send_time = start_time + (i + 1) / 24000 / rtf
|
|
||||||
current_time = time.time()
|
|
||||||
if current_time < expected_send_time:
|
|
||||||
await asyncio.sleep(expected_send_time - current_time)
|
|
||||||
else:
|
|
||||||
await asyncio.sleep(0.001)
|
|
||||||
chunk = { "type": "Audio", "pcm": [0.0] * 1920 * 5 }
|
|
||||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
|
||||||
await websocket.send(msg)
|
|
||||||
msg = msgpack.packb({"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True)
|
|
||||||
await websocket.send(msg)
|
|
||||||
for _ in range(35):
|
|
||||||
chunk = { "type": "Audio", "pcm": [0.0] * 1920 }
|
|
||||||
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
|
||||||
await websocket.send(msg)
|
|
||||||
while True:
|
|
||||||
if finished:
|
|
||||||
break
|
|
||||||
await asyncio.sleep(1.0)
|
|
||||||
# Keep the connection alive as there is a 20s timeout on the rust side.
|
|
||||||
await websocket.ping()
|
|
||||||
except websockets.ConnectionClosed:
|
|
||||||
print("Connection closed while sending messages.")
|
|
||||||
|
|
||||||
|
|
||||||
async def stream_audio(url: str, rtf: float):
|
|
||||||
"""Stream audio data to a WebSocket server."""
|
|
||||||
|
|
||||||
async with websockets.connect(url, additional_headers=HEADERS) as websocket:
|
|
||||||
send_task = asyncio.create_task(send_messages(websocket, rtf))
|
|
||||||
receive_task = asyncio.create_task(receive_messages(websocket))
|
|
||||||
await asyncio.gather(send_task, receive_task)
|
|
||||||
print("exiting")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("in_file")
|
|
||||||
parser.add_argument("--transcript")
|
|
||||||
parser.add_argument(
|
|
||||||
"--url",
|
|
||||||
help="The url of the server to which to send the audio",
|
|
||||||
default="ws://127.0.0.1:8080",
|
|
||||||
)
|
|
||||||
parser.add_argument("--rtf", type=float, default=1.01)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
url = f"{args.url}/api/asr-streaming"
|
|
||||||
asyncio.run(stream_audio(url, args.rtf))
|
|
||||||
print(" ".join(all_text))
|
|
||||||
if args.transcript is not None:
|
|
||||||
with open(args.transcript, "w") as fobj:
|
|
||||||
json.dump({"transcript": transcript}, fobj, indent=4)
|
|
||||||
387
scripts/stt_evaluate_on_dataset.py
Normal file
387
scripts/stt_evaluate_on_dataset.py
Normal file
|
|
@ -0,0 +1,387 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "datasets",
|
||||||
|
# "jiwer==3.1.0",
|
||||||
|
# "julius",
|
||||||
|
# "librosa",
|
||||||
|
# "moshi",
|
||||||
|
# "openai-whisper",
|
||||||
|
# "soundfile",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
"""
|
||||||
|
Example implementation of the streaming STT example. Here we group
|
||||||
|
test utterances in batches (pre- and post-padded with silence) and
|
||||||
|
and then feed these batches into the streaming STT model frame-by-frame.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The outputs I get on my H100 using this code with the 2.6B model,
|
||||||
|
# bsz 32:
|
||||||
|
|
||||||
|
# LibriVox === cer: 4.09% wer: 7.33% corpus_wer: 6.78% RTF = 52.72
|
||||||
|
# Ami === cer: 15.99% wer: 18.78% corpus_wer: 12.20% RTF = 28.37
|
||||||
|
# LibriSpeech other === cer: 2.31% wer: 5.24% corpus_wer: 4.33% RTF = 44.76
|
||||||
|
# LibriSpeech clean === cer: 0.67% wer: 1.95% corpus_wer: 1.69% RTF = 68.19
|
||||||
|
# Tedlium (short) === cer: 2.15% wer: 3.65% corpus_wer: 3.33% RTF = 67.44
|
||||||
|
# spgispeech === cer: 0.99% wer: 2.00% corpus_wer: 2.03% RTF = 78.64
|
||||||
|
# gigaspeech === cer: 6.80% wer: 11.31% corpus_wer: 9.81% RTF = 64.04
|
||||||
|
# earnings22 (short) === cer: 12.63% wer: 15.70% corpus_wer: 11.02% RTF = 50.13
|
||||||
|
|
||||||
|
# Meanwhile === cer: 2.02% wer: 5.50% corpus_wer: 5.60% RTF = 69.19
|
||||||
|
# Tedlium (long) == cer: 1.53% wer: 2.56% corpus_wer: 2.97% RTF = 33.92
|
||||||
|
# Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34
|
||||||
|
# Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
import time
|
||||||
|
|
||||||
|
import jiwer
|
||||||
|
import julius
|
||||||
|
import moshi.models
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from datasets import Dataset, load_dataset
|
||||||
|
from whisper.normalizers import EnglishTextNormalizer
|
||||||
|
|
||||||
|
_NORMALIZER = EnglishTextNormalizer()
|
||||||
|
|
||||||
|
|
||||||
|
def get_text(sample):
|
||||||
|
possible_keys = [
|
||||||
|
"text",
|
||||||
|
"sentence",
|
||||||
|
"normalized_text",
|
||||||
|
"transcript",
|
||||||
|
"transcription",
|
||||||
|
]
|
||||||
|
for key in possible_keys:
|
||||||
|
if key in sample:
|
||||||
|
return sample[key]
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected transcript column of either {possible_keys}."
|
||||||
|
f"Got sample with keys: {', '.join(sample.keys())}. Ensure a text column name is present in the dataset."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# The two functions below are adapted from https://github.com/huggingface/open_asr_leaderboard/blob/main/normalizer/data_utils.py
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(batch):
|
||||||
|
batch["original_text"] = get_text(batch)
|
||||||
|
batch["norm_text"] = _NORMALIZER(batch["original_text"])
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def is_target_text_in_range(ref):
|
||||||
|
if ref.strip() == "ignore time segment in scoring":
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return ref.strip() != ""
|
||||||
|
|
||||||
|
|
||||||
|
# End of the adapted part
|
||||||
|
|
||||||
|
|
||||||
|
class AsrMetrics:
|
||||||
|
def __init__(self):
|
||||||
|
self.cer_sum = 0.0
|
||||||
|
self.wer_sum = 0.0
|
||||||
|
self.errors_sum = 0.0
|
||||||
|
self.total_words_sum = 0.0
|
||||||
|
self.num_sequences = 0.0
|
||||||
|
|
||||||
|
def update(self, hyp: str, ref: str) -> None:
|
||||||
|
normalized_ref = _NORMALIZER(ref)
|
||||||
|
normalized_hyp = _NORMALIZER(hyp)
|
||||||
|
|
||||||
|
this_wer = jiwer.wer(normalized_ref, normalized_hyp)
|
||||||
|
this_cer = jiwer.cer(normalized_ref, normalized_hyp)
|
||||||
|
measures = jiwer.compute_measures(normalized_ref, normalized_hyp)
|
||||||
|
|
||||||
|
self.wer_sum += this_wer
|
||||||
|
self.cer_sum += this_cer
|
||||||
|
self.errors_sum += (
|
||||||
|
measures["substitutions"] + measures["deletions"] + measures["insertions"]
|
||||||
|
)
|
||||||
|
self.total_words_sum += (
|
||||||
|
measures["substitutions"] + measures["deletions"] + measures["hits"]
|
||||||
|
)
|
||||||
|
self.num_sequences += 1
|
||||||
|
|
||||||
|
def compute(self) -> dict:
|
||||||
|
assert self.num_sequences > 0, (
|
||||||
|
"Unable to compute with total number of comparisons <= 0"
|
||||||
|
) # type: ignore
|
||||||
|
return {
|
||||||
|
"cer": (self.cer_sum / self.num_sequences),
|
||||||
|
"wer": (self.wer_sum / self.num_sequences),
|
||||||
|
"corpus_wer": (self.errors_sum / self.total_words_sum),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
result = self.compute()
|
||||||
|
return " ".join(f"{k}: {100 * v:.2f}%" for k, v in result.items())
|
||||||
|
|
||||||
|
|
||||||
|
class Timer:
|
||||||
|
def __init__(self):
|
||||||
|
self.total = 0
|
||||||
|
self._start_time = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self._start_time = time.perf_counter()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *_):
|
||||||
|
self.total += time.perf_counter() - self._start_time
|
||||||
|
self._start_time = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class _DatasetInfo:
|
||||||
|
alias: str
|
||||||
|
|
||||||
|
name: str
|
||||||
|
config: str
|
||||||
|
split: str = "test"
|
||||||
|
|
||||||
|
|
||||||
|
_DATASETS = [
|
||||||
|
# Long-form datasets from distil-whisper
|
||||||
|
_DatasetInfo("rev16", "distil-whisper/rev16", "whisper_subset"),
|
||||||
|
_DatasetInfo("earnings21", "distil-whisper/earnings21", "full"),
|
||||||
|
_DatasetInfo("earnings22", "distil-whisper/earnings22", "full"),
|
||||||
|
_DatasetInfo("tedlium", "distil-whisper/tedlium-long-form", None),
|
||||||
|
_DatasetInfo("meanwhile", "distil-whisper/meanwhile", None),
|
||||||
|
# Short-form datasets from OpenASR leaderboard
|
||||||
|
_DatasetInfo("ami", "hf-audio/esb-datasets-test-only-sorted", "ami"),
|
||||||
|
_DatasetInfo(
|
||||||
|
"librispeech.clean",
|
||||||
|
"hf-audio/esb-datasets-test-only-sorted",
|
||||||
|
"librispeech",
|
||||||
|
split="test.clean",
|
||||||
|
),
|
||||||
|
_DatasetInfo(
|
||||||
|
"librispeech.other",
|
||||||
|
"hf-audio/esb-datasets-test-only-sorted",
|
||||||
|
"librispeech",
|
||||||
|
split="test.other",
|
||||||
|
),
|
||||||
|
_DatasetInfo("voxpopuli", "hf-audio/esb-datasets-test-only-sorted", "voxpopuli"),
|
||||||
|
_DatasetInfo("spgispeech", "hf-audio/esb-datasets-test-only-sorted", "spgispeech"),
|
||||||
|
_DatasetInfo("gigaspeech", "hf-audio/esb-datasets-test-only-sorted", "gigaspeech"),
|
||||||
|
_DatasetInfo("tedlium-short", "hf-audio/esb-datasets-test-only-sorted", "tedlium"),
|
||||||
|
_DatasetInfo(
|
||||||
|
"earnings22-short", "hf-audio/esb-datasets-test-only-sorted", "earnings22"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
DATASET_MAP = {dataset.alias: dataset for dataset in _DATASETS}
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset(args) -> Dataset:
|
||||||
|
if args.dataset not in DATASET_MAP:
|
||||||
|
raise RuntimeError(f"Unknown dataset: {args.dataset}")
|
||||||
|
|
||||||
|
info = DATASET_MAP[args.dataset]
|
||||||
|
|
||||||
|
dataset = load_dataset(
|
||||||
|
info.name,
|
||||||
|
info.config,
|
||||||
|
split=info.split,
|
||||||
|
cache_dir=args.hf_cache_dir,
|
||||||
|
streaming=False,
|
||||||
|
token=True,
|
||||||
|
)
|
||||||
|
dataset = dataset.map(normalize)
|
||||||
|
dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad
|
||||||
|
def get_padded_batch(
|
||||||
|
audios: list[tuple[torch.Tensor, int]],
|
||||||
|
before_padding: float,
|
||||||
|
after_padding: float,
|
||||||
|
audio_encoder,
|
||||||
|
):
|
||||||
|
sample_rate = audio_encoder.sample_rate
|
||||||
|
|
||||||
|
max_len = 0
|
||||||
|
batch = []
|
||||||
|
durations = []
|
||||||
|
for audio, sr in audios:
|
||||||
|
durations.append(audio.shape[-1] / sr)
|
||||||
|
audio = julius.resample_frac(audio, int(sr), int(sample_rate))
|
||||||
|
audio = torch.nn.functional.pad(
|
||||||
|
audio, (int(before_padding * sample_rate), int(after_padding * sample_rate))
|
||||||
|
)
|
||||||
|
max_len = max(max_len, audio.shape[-1])
|
||||||
|
batch.append(audio)
|
||||||
|
|
||||||
|
target = max_len
|
||||||
|
if target % audio_encoder.frame_size != 0:
|
||||||
|
target = target + (
|
||||||
|
audio_encoder.frame_size - max_len % audio_encoder.frame_size
|
||||||
|
)
|
||||||
|
padded_batch = torch.stack(
|
||||||
|
[
|
||||||
|
torch.nn.functional.pad(audio, (0, target - audio.shape[-1]))
|
||||||
|
for audio in batch
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return padded_batch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad
|
||||||
|
def streaming_transcribe(
|
||||||
|
padded_batch: torch.Tensor,
|
||||||
|
mimi,
|
||||||
|
lm_gen,
|
||||||
|
):
|
||||||
|
bsz = padded_batch.shape[0]
|
||||||
|
|
||||||
|
text_tokens_acc = []
|
||||||
|
|
||||||
|
with mimi.streaming(bsz), lm_gen.streaming(bsz):
|
||||||
|
for offset in range(0, padded_batch.shape[-1], mimi.frame_size):
|
||||||
|
audio_chunk = padded_batch[:, offset : offset + mimi.frame_size]
|
||||||
|
audio_chunk = audio_chunk[:, None, :]
|
||||||
|
|
||||||
|
audio_tokens = mimi.encode(audio_chunk)
|
||||||
|
text_tokens = lm_gen.step(audio_tokens)
|
||||||
|
if text_tokens is not None:
|
||||||
|
text_tokens_acc.append(text_tokens)
|
||||||
|
|
||||||
|
return torch.concat(text_tokens_acc, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def run_inference(
|
||||||
|
dataset,
|
||||||
|
mimi,
|
||||||
|
lm_gen,
|
||||||
|
tokenizer,
|
||||||
|
padding_token_id,
|
||||||
|
before_padding_sec,
|
||||||
|
after_padding_sec,
|
||||||
|
):
|
||||||
|
metrics = AsrMetrics()
|
||||||
|
audio_time = 0.0
|
||||||
|
inference_timer = Timer()
|
||||||
|
|
||||||
|
for batch in tqdm.tqdm(dataset.iter(args.batch_size)):
|
||||||
|
audio_data = list(
|
||||||
|
zip(
|
||||||
|
[torch.tensor(x["array"]).float() for x in batch["audio"]],
|
||||||
|
[x["sampling_rate"] for x in batch["audio"]],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_time += sum(audio.shape[-1] / sr for (audio, sr) in audio_data)
|
||||||
|
|
||||||
|
gt_transcripts = batch["original_text"]
|
||||||
|
|
||||||
|
padded_batch = get_padded_batch(
|
||||||
|
audio_data,
|
||||||
|
before_padding=before_padding_sec,
|
||||||
|
after_padding=after_padding_sec,
|
||||||
|
audio_encoder=mimi,
|
||||||
|
)
|
||||||
|
padded_batch = padded_batch.cuda()
|
||||||
|
|
||||||
|
with inference_timer:
|
||||||
|
text_tokens = streaming_transcribe(
|
||||||
|
padded_batch,
|
||||||
|
mimi=mimi,
|
||||||
|
lm_gen=lm_gen,
|
||||||
|
)
|
||||||
|
|
||||||
|
for batch_index in range(text_tokens.shape[0]):
|
||||||
|
utterance_tokens = text_tokens[batch_index, ...]
|
||||||
|
utterance_tokens = utterance_tokens[utterance_tokens > padding_token_id]
|
||||||
|
text = tokenizer.decode(utterance_tokens.cpu().numpy().tolist())
|
||||||
|
metrics.update(hyp=text, ref=gt_transcripts[batch_index])
|
||||||
|
|
||||||
|
return metrics, inference_timer.total, audio_time
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
torch.set_float32_matmul_precision("high")
|
||||||
|
|
||||||
|
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
|
||||||
|
args.hf_repo,
|
||||||
|
moshi_weights=args.moshi_weight,
|
||||||
|
mimi_weights=args.mimi_weight,
|
||||||
|
tokenizer=args.tokenizer,
|
||||||
|
config_path=args.config_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
mimi = info.get_mimi(device=args.device)
|
||||||
|
tokenizer = info.get_text_tokenizer()
|
||||||
|
lm = info.get_moshi(
|
||||||
|
device=args.device,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
|
||||||
|
dataset = get_dataset(args)
|
||||||
|
|
||||||
|
padding_token_id = info.raw_config.get("text_padding_token_id", 3)
|
||||||
|
# Putting in some conservative defaults
|
||||||
|
audio_silence_prefix_seconds = info.stt_config.get(
|
||||||
|
"audio_silence_prefix_seconds", 1.0
|
||||||
|
)
|
||||||
|
audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
|
||||||
|
|
||||||
|
wer_metric, inference_time, audio_time = run_inference(
|
||||||
|
dataset,
|
||||||
|
mimi,
|
||||||
|
lm_gen,
|
||||||
|
tokenizer,
|
||||||
|
padding_token_id,
|
||||||
|
audio_silence_prefix_seconds,
|
||||||
|
audio_delay_seconds + 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(wer_metric, f"RTF = {audio_time / inference_time:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Example streaming STT inference.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
required=True,
|
||||||
|
choices=DATASET_MAP.keys(),
|
||||||
|
help="Dataset to run inference on.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-repo", type=str, help="HF repo to load the STT model from."
|
||||||
|
)
|
||||||
|
parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--moshi-weight", type=str, help="Path to a local checkpoint file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config-path", type=str, help="Path to a local config file.", default=None
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
help="Batch size.",
|
||||||
|
default=32,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda",
|
||||||
|
help="Device on which to run, defaults to 'cuda'.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--hf-cache-dir", type=str, help="HuggingFace cache folder.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
100
scripts/stt_from_file_mlx.py
Normal file
100
scripts/stt_from_file_mlx.py
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "huggingface_hub",
|
||||||
|
# "moshi_mlx==0.2.12",
|
||||||
|
# "numpy",
|
||||||
|
# "sentencepiece",
|
||||||
|
# "sounddevice",
|
||||||
|
# "sphn",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import sentencepiece
|
||||||
|
import sphn
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from moshi_mlx import models, utils
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("in_file", help="The file to transcribe.")
|
||||||
|
parser.add_argument("--max-steps", default=4096)
|
||||||
|
parser.add_argument("--hf-repo")
|
||||||
|
parser.add_argument(
|
||||||
|
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
audio, _ = sphn.read(args.in_file, sample_rate=24000)
|
||||||
|
if args.hf_repo is None:
|
||||||
|
if args.vad:
|
||||||
|
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
|
||||||
|
else:
|
||||||
|
args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
|
||||||
|
lm_config = hf_hub_download(args.hf_repo, "config.json")
|
||||||
|
with open(lm_config, "r") as fobj:
|
||||||
|
lm_config = json.load(fobj)
|
||||||
|
mimi_weights = hf_hub_download(args.hf_repo, lm_config["mimi_name"])
|
||||||
|
moshi_name = lm_config.get("moshi_name", "model.safetensors")
|
||||||
|
moshi_weights = hf_hub_download(args.hf_repo, moshi_name)
|
||||||
|
text_tokenizer = hf_hub_download(args.hf_repo, lm_config["tokenizer_name"])
|
||||||
|
|
||||||
|
lm_config = models.LmConfig.from_config_dict(lm_config)
|
||||||
|
model = models.Lm(lm_config)
|
||||||
|
model.set_dtype(mx.bfloat16)
|
||||||
|
if moshi_weights.endswith(".q4.safetensors"):
|
||||||
|
nn.quantize(model, bits=4, group_size=32)
|
||||||
|
elif moshi_weights.endswith(".q8.safetensors"):
|
||||||
|
nn.quantize(model, bits=8, group_size=64)
|
||||||
|
|
||||||
|
print(f"loading model weights from {moshi_weights}")
|
||||||
|
if args.hf_repo.endswith("-candle"):
|
||||||
|
model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
|
||||||
|
else:
|
||||||
|
model.load_weights(moshi_weights, strict=True)
|
||||||
|
|
||||||
|
print(f"loading the text tokenizer from {text_tokenizer}")
|
||||||
|
text_tokenizer = sentencepiece.SentencePieceProcessor(text_tokenizer) # type: ignore
|
||||||
|
|
||||||
|
print(f"loading the audio tokenizer {mimi_weights}")
|
||||||
|
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(32))
|
||||||
|
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
|
||||||
|
print("warming up the model")
|
||||||
|
model.warmup()
|
||||||
|
gen = models.LmGen(
|
||||||
|
model=model,
|
||||||
|
max_steps=args.max_steps,
|
||||||
|
text_sampler=utils.Sampler(top_k=25, temp=0),
|
||||||
|
audio_sampler=utils.Sampler(top_k=250, temp=0.8),
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"starting inference {audio.shape}")
|
||||||
|
audio = mx.concat([mx.array(audio), mx.zeros((1, 48000))], axis=-1)
|
||||||
|
last_print_was_vad = False
|
||||||
|
for start_idx in range(0, audio.shape[-1] // 1920 * 1920, 1920):
|
||||||
|
block = audio[:, None, start_idx : start_idx + 1920]
|
||||||
|
other_audio_tokens = audio_tokenizer.encode_step(block).transpose(0, 2, 1)
|
||||||
|
if args.vad:
|
||||||
|
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
|
||||||
|
if vad_heads:
|
||||||
|
pr_vad = vad_heads[2][0, 0, 0].item()
|
||||||
|
if pr_vad > 0.5 and not last_print_was_vad:
|
||||||
|
print(" [end of turn detected]")
|
||||||
|
last_print_was_vad = True
|
||||||
|
else:
|
||||||
|
text_token = gen.step(other_audio_tokens[0])
|
||||||
|
text_token = text_token[0].item()
|
||||||
|
audio_tokens = gen.last_audio_tokens()
|
||||||
|
_text = None
|
||||||
|
if text_token not in (0, 3):
|
||||||
|
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
|
||||||
|
_text = _text.replace("▁", " ")
|
||||||
|
print(_text, end="", flush=True)
|
||||||
|
last_print_was_vad = False
|
||||||
|
print()
|
||||||
247
scripts/stt_from_file_pytorch.py
Normal file
247
scripts/stt_from_file_pytorch.py
Normal file
|
|
@ -0,0 +1,247 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "julius",
|
||||||
|
# "librosa",
|
||||||
|
# "soundfile",
|
||||||
|
# "moshi==0.2.11",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
"""An example script that illustrates how one can get per-word timestamps from
|
||||||
|
Kyutai STT models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
|
||||||
|
import julius
|
||||||
|
import moshi.models
|
||||||
|
import sphn
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TimestampedText:
|
||||||
|
text: str
|
||||||
|
timestamp: tuple[float, float]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.text} ({self.timestamp[0]:.2f}:{self.timestamp[1]:.2f})"
|
||||||
|
|
||||||
|
|
||||||
|
def tokens_to_timestamped_text(
|
||||||
|
text_tokens,
|
||||||
|
tokenizer,
|
||||||
|
frame_rate,
|
||||||
|
end_of_padding_id,
|
||||||
|
padding_token_id,
|
||||||
|
offset_seconds,
|
||||||
|
) -> list[TimestampedText]:
|
||||||
|
text_tokens = text_tokens.cpu().view(-1)
|
||||||
|
|
||||||
|
# Normally `end_of_padding` tokens indicate word boundaries.
|
||||||
|
# Everything between them should be a single word;
|
||||||
|
# the time offset of the those tokens correspond to word start and
|
||||||
|
# end timestamps (minus silence prefix and audio delay).
|
||||||
|
#
|
||||||
|
# However, in rare cases some complexities could arise. Firstly,
|
||||||
|
# for words that are said quickly but are represented with
|
||||||
|
# multiple tokens, the boundary might be omitted. Secondly,
|
||||||
|
# for the very last word the end boundary might not happen.
|
||||||
|
# Below is a code snippet that handles those situations a bit
|
||||||
|
# more carefully.
|
||||||
|
|
||||||
|
sequence_timestamps = []
|
||||||
|
|
||||||
|
def _tstmp(start_position, end_position):
|
||||||
|
return (
|
||||||
|
max(0, start_position / frame_rate - offset_seconds),
|
||||||
|
max(0, end_position / frame_rate - offset_seconds),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _decode(t):
|
||||||
|
t = t[t > padding_token_id]
|
||||||
|
return tokenizer.decode(t.numpy().tolist())
|
||||||
|
|
||||||
|
def _decode_segment(start, end):
|
||||||
|
nonlocal text_tokens
|
||||||
|
nonlocal sequence_timestamps
|
||||||
|
|
||||||
|
text = _decode(text_tokens[start:end])
|
||||||
|
words_inside_segment = text.split()
|
||||||
|
|
||||||
|
if len(words_inside_segment) == 0:
|
||||||
|
return
|
||||||
|
if len(words_inside_segment) == 1:
|
||||||
|
# Single word within the boundaries, the general case
|
||||||
|
sequence_timestamps.append(
|
||||||
|
TimestampedText(text=text, timestamp=_tstmp(start, end))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# We're in a rare situation where multiple words are so close they are not separated by `end_of_padding`.
|
||||||
|
# We tokenize words one-by-one; each word is assigned with as many frames as much tokens it has.
|
||||||
|
for adjacent_word in words_inside_segment[:-1]:
|
||||||
|
n_tokens = len(tokenizer.encode(adjacent_word))
|
||||||
|
sequence_timestamps.append(
|
||||||
|
TimestampedText(
|
||||||
|
text=adjacent_word, timestamp=_tstmp(start, start + n_tokens)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
start += n_tokens
|
||||||
|
|
||||||
|
# The last word takes everything until the boundary
|
||||||
|
adjacent_word = words_inside_segment[-1]
|
||||||
|
sequence_timestamps.append(
|
||||||
|
TimestampedText(text=adjacent_word, timestamp=_tstmp(start, end))
|
||||||
|
)
|
||||||
|
|
||||||
|
(segment_boundaries,) = torch.where(text_tokens == end_of_padding_id)
|
||||||
|
|
||||||
|
if not segment_boundaries.numel():
|
||||||
|
return []
|
||||||
|
|
||||||
|
for i in range(len(segment_boundaries) - 1):
|
||||||
|
segment_start = int(segment_boundaries[i]) + 1
|
||||||
|
segment_end = int(segment_boundaries[i + 1])
|
||||||
|
|
||||||
|
_decode_segment(segment_start, segment_end)
|
||||||
|
|
||||||
|
last_segment_start = segment_boundaries[-1] + 1
|
||||||
|
|
||||||
|
boundary_token = torch.tensor([tokenizer.eos_id()])
|
||||||
|
(end_of_last_segment,) = torch.where(
|
||||||
|
torch.isin(text_tokens[last_segment_start:], boundary_token)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not end_of_last_segment.numel():
|
||||||
|
# upper-bound either end of the audio or 1 second duration, whicher is smaller
|
||||||
|
last_segment_end = min(text_tokens.shape[-1], last_segment_start + frame_rate)
|
||||||
|
else:
|
||||||
|
last_segment_end = last_segment_start + end_of_last_segment[0]
|
||||||
|
_decode_segment(last_segment_start, last_segment_end)
|
||||||
|
|
||||||
|
return sequence_timestamps
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
if args.vad and args.hf_repo is None:
|
||||||
|
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
|
||||||
|
|
||||||
|
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
|
||||||
|
args.hf_repo,
|
||||||
|
moshi_weights=args.moshi_weight,
|
||||||
|
mimi_weights=args.mimi_weight,
|
||||||
|
tokenizer=args.tokenizer,
|
||||||
|
config_path=args.config_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
mimi = info.get_mimi(device=args.device)
|
||||||
|
tokenizer = info.get_text_tokenizer()
|
||||||
|
lm = info.get_moshi(
|
||||||
|
device=args.device,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
|
||||||
|
|
||||||
|
audio_silence_prefix_seconds = info.stt_config.get(
|
||||||
|
"audio_silence_prefix_seconds", 1.0
|
||||||
|
)
|
||||||
|
audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
|
||||||
|
padding_token_id = info.raw_config.get("text_padding_token_id", 3)
|
||||||
|
|
||||||
|
audio, input_sample_rate = sphn.read(args.in_file)
|
||||||
|
audio = torch.from_numpy(audio).to(args.device)
|
||||||
|
audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
|
||||||
|
if audio.shape[-1] % mimi.frame_size != 0:
|
||||||
|
to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
|
||||||
|
audio = torch.nn.functional.pad(audio, (0, to_pad))
|
||||||
|
|
||||||
|
text_tokens_accum = []
|
||||||
|
|
||||||
|
n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
|
||||||
|
n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
|
||||||
|
silence_chunk = torch.zeros(
|
||||||
|
(1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = itertools.chain(
|
||||||
|
itertools.repeat(silence_chunk, n_prefix_chunks),
|
||||||
|
torch.split(audio[:, None], mimi.frame_size, dim=-1),
|
||||||
|
itertools.repeat(silence_chunk, n_suffix_chunks),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
nchunks = 0
|
||||||
|
last_print_was_vad = False
|
||||||
|
with mimi.streaming(1), lm_gen.streaming(1):
|
||||||
|
for audio_chunk in chunks:
|
||||||
|
nchunks += 1
|
||||||
|
audio_tokens = mimi.encode(audio_chunk)
|
||||||
|
if args.vad:
|
||||||
|
text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens)
|
||||||
|
if vad_heads:
|
||||||
|
pr_vad = vad_heads[2][0, 0, 0].cpu().item()
|
||||||
|
if pr_vad > 0.5 and not last_print_was_vad:
|
||||||
|
print(" [end of turn detected]")
|
||||||
|
last_print_was_vad = True
|
||||||
|
else:
|
||||||
|
text_tokens = lm_gen.step(audio_tokens)
|
||||||
|
text_token = text_tokens[0, 0, 0].cpu().item()
|
||||||
|
if text_token not in (0, 3):
|
||||||
|
_text = tokenizer.id_to_piece(text_tokens[0, 0, 0].cpu().item()) # type: ignore
|
||||||
|
_text = _text.replace("▁", " ")
|
||||||
|
print(_text, end="", flush=True)
|
||||||
|
last_print_was_vad = False
|
||||||
|
text_tokens_accum.append(text_tokens)
|
||||||
|
|
||||||
|
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
|
||||||
|
dt = time.time() - start_time
|
||||||
|
print(
|
||||||
|
f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}"
|
||||||
|
)
|
||||||
|
timed_text = tokens_to_timestamped_text(
|
||||||
|
utterance_tokens,
|
||||||
|
tokenizer,
|
||||||
|
mimi.frame_rate,
|
||||||
|
end_of_padding_id=0,
|
||||||
|
padding_token_id=padding_token_id,
|
||||||
|
offset_seconds=int(n_prefix_chunks / mimi.frame_rate) + audio_delay_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoded = " ".join([str(t) for t in timed_text])
|
||||||
|
print(decoded)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Example streaming STT w/ timestamps.")
|
||||||
|
parser.add_argument("in_file", help="The file to transcribe.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-repo", type=str, help="HF repo to load the STT model from. "
|
||||||
|
)
|
||||||
|
parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--moshi-weight", type=str, help="Path to a local checkpoint file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config-path", type=str, help="Path to a local config file.", default=None
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda",
|
||||||
|
help="Device on which to run, defaults to 'cuda'.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
135
scripts/stt_from_file_rust_server.py
Normal file
135
scripts/stt_from_file_rust_server.py
Normal file
|
|
@ -0,0 +1,135 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "msgpack",
|
||||||
|
# "numpy",
|
||||||
|
# "sphn",
|
||||||
|
# "websockets",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import numpy as np
|
||||||
|
import sphn
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
SAMPLE_RATE = 24000
|
||||||
|
FRAME_SIZE = 1920 # Send data in chunks
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_process_audio(file_path):
|
||||||
|
"""Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
|
||||||
|
pcm_data, _ = sphn.read(file_path, sample_rate=SAMPLE_RATE)
|
||||||
|
return pcm_data[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def receive_messages(websocket):
|
||||||
|
transcript = []
|
||||||
|
|
||||||
|
async for message in websocket:
|
||||||
|
data = msgpack.unpackb(message, raw=False)
|
||||||
|
if data["type"] == "Step":
|
||||||
|
# This message contains the signal from the semantic VAD, and tells us how
|
||||||
|
# much audio the server has already processed. We don't use either here.
|
||||||
|
continue
|
||||||
|
if data["type"] == "Word":
|
||||||
|
print(data["text"], end=" ", flush=True)
|
||||||
|
transcript.append(
|
||||||
|
{
|
||||||
|
"text": data["text"],
|
||||||
|
"timestamp": [data["start_time"], data["start_time"]],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if data["type"] == "EndWord":
|
||||||
|
if len(transcript) > 0:
|
||||||
|
transcript[-1]["timestamp"][1] = data["stop_time"]
|
||||||
|
if data["type"] == "Marker":
|
||||||
|
# Received marker, stopping stream
|
||||||
|
break
|
||||||
|
|
||||||
|
return transcript
|
||||||
|
|
||||||
|
|
||||||
|
async def send_messages(websocket, rtf: float):
|
||||||
|
audio_data = load_and_process_audio(args.in_file)
|
||||||
|
|
||||||
|
async def send_audio(audio: np.ndarray):
|
||||||
|
await websocket.send(
|
||||||
|
msgpack.packb(
|
||||||
|
{"type": "Audio", "pcm": [float(x) for x in audio]},
|
||||||
|
use_single_float=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start with a second of silence.
|
||||||
|
# This is needed for the 2.6B model for technical reasons.
|
||||||
|
await send_audio([0.0] * SAMPLE_RATE)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
for i in range(0, len(audio_data), FRAME_SIZE):
|
||||||
|
await send_audio(audio_data[i : i + FRAME_SIZE])
|
||||||
|
|
||||||
|
expected_send_time = start_time + (i + 1) / SAMPLE_RATE / rtf
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time < expected_send_time:
|
||||||
|
await asyncio.sleep(expected_send_time - current_time)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.001)
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
await send_audio([0.0] * SAMPLE_RATE)
|
||||||
|
|
||||||
|
# Send a marker to indicate the end of the stream.
|
||||||
|
await websocket.send(
|
||||||
|
msgpack.packb({"type": "Marker", "id": 0}, use_single_float=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# We'll get back the marker once the corresponding audio has been transcribed,
|
||||||
|
# accounting for the delay of the model. That's why we need to send some silence
|
||||||
|
# after the marker, because the model will not return the marker immediately.
|
||||||
|
for _ in range(35):
|
||||||
|
await send_audio([0.0] * SAMPLE_RATE)
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_audio(url: str, api_key: str, rtf: float):
|
||||||
|
"""Stream audio data to a WebSocket server."""
|
||||||
|
headers = {"kyutai-api-key": api_key}
|
||||||
|
|
||||||
|
# Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
|
||||||
|
async with websockets.connect(url, additional_headers=headers) as websocket:
|
||||||
|
send_task = asyncio.create_task(send_messages(websocket, rtf))
|
||||||
|
receive_task = asyncio.create_task(receive_messages(websocket))
|
||||||
|
_, transcript = await asyncio.gather(send_task, receive_task)
|
||||||
|
|
||||||
|
return transcript
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("in_file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--url",
|
||||||
|
help="The url of the server to which to send the audio",
|
||||||
|
default="ws://127.0.0.1:8080",
|
||||||
|
)
|
||||||
|
parser.add_argument("--api-key", default="public_token")
|
||||||
|
parser.add_argument(
|
||||||
|
"--rtf",
|
||||||
|
type=float,
|
||||||
|
default=1.01,
|
||||||
|
help="The real-time factor of how fast to feed in the audio.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
url = f"{args.url}/api/asr-streaming"
|
||||||
|
transcript = asyncio.run(stream_audio(url, args.api_key, args.rtf))
|
||||||
|
|
||||||
|
print()
|
||||||
|
print()
|
||||||
|
for word in transcript:
|
||||||
|
print(
|
||||||
|
f"{word['timestamp'][0]:7.2f} -{word['timestamp'][1]:7.2f} {word['text']}"
|
||||||
|
)
|
||||||
187
scripts/stt_from_file_with_prompt_pytorch.py
Normal file
187
scripts/stt_from_file_with_prompt_pytorch.py
Normal file
|
|
@ -0,0 +1,187 @@
|
||||||
|
"""An example script that illustrates how one can prompt Kyutai STT models."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import julius
|
||||||
|
import moshi.models
|
||||||
|
import sphn
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class PromptHook:
|
||||||
|
def __init__(self, tokenizer, prefix, padding_tokens=(0, 3)):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.prefix_enforce = deque(self.tokenizer.encode(prefix))
|
||||||
|
self.padding_tokens = padding_tokens
|
||||||
|
|
||||||
|
def on_token(self, token):
|
||||||
|
if not self.prefix_enforce:
|
||||||
|
return
|
||||||
|
|
||||||
|
token = token.item()
|
||||||
|
|
||||||
|
if token in self.padding_tokens:
|
||||||
|
pass
|
||||||
|
elif token == self.prefix_enforce[0]:
|
||||||
|
self.prefix_enforce.popleft()
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
|
def on_logits(self, logits):
|
||||||
|
if not self.prefix_enforce:
|
||||||
|
return
|
||||||
|
|
||||||
|
mask = torch.zeros_like(logits, dtype=torch.bool)
|
||||||
|
for t in self.padding_tokens:
|
||||||
|
mask[..., t] = True
|
||||||
|
mask[..., self.prefix_enforce[0]] = True
|
||||||
|
|
||||||
|
logits[:] = torch.where(mask, logits, float("-inf"))
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
|
||||||
|
args.hf_repo,
|
||||||
|
moshi_weights=args.moshi_weight,
|
||||||
|
mimi_weights=args.mimi_weight,
|
||||||
|
tokenizer=args.tokenizer,
|
||||||
|
config_path=args.config_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
mimi = info.get_mimi(device=args.device)
|
||||||
|
tokenizer = info.get_text_tokenizer()
|
||||||
|
lm = info.get_moshi(
|
||||||
|
device=args.device,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.prompt_text:
|
||||||
|
prompt_hook = PromptHook(tokenizer, args.prompt_text)
|
||||||
|
lm_gen = moshi.models.LMGen(
|
||||||
|
lm,
|
||||||
|
temp=0,
|
||||||
|
temp_text=0.0,
|
||||||
|
on_text_hook=prompt_hook.on_token,
|
||||||
|
on_text_logits_hook=prompt_hook.on_logits,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
|
||||||
|
|
||||||
|
audio_silence_prefix_seconds = info.stt_config.get(
|
||||||
|
"audio_silence_prefix_seconds", 1.0
|
||||||
|
)
|
||||||
|
audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
|
||||||
|
padding_token_id = info.raw_config.get("text_padding_token_id", 3)
|
||||||
|
|
||||||
|
def _load_and_process(path):
|
||||||
|
audio, input_sample_rate = sphn.read(path)
|
||||||
|
audio = torch.from_numpy(audio).to(args.device).mean(axis=0, keepdim=True)
|
||||||
|
audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
|
||||||
|
if audio.shape[-1] % mimi.frame_size != 0:
|
||||||
|
to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
|
||||||
|
audio = torch.nn.functional.pad(audio, (0, to_pad))
|
||||||
|
return audio
|
||||||
|
|
||||||
|
n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
|
||||||
|
n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
|
||||||
|
silence_chunk = torch.zeros(
|
||||||
|
(1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
|
||||||
|
)
|
||||||
|
|
||||||
|
audio = _load_and_process(args.file)
|
||||||
|
if args.prompt_file:
|
||||||
|
audio_prompt = _load_and_process(args.prompt_file)
|
||||||
|
else:
|
||||||
|
audio_prompt = None
|
||||||
|
|
||||||
|
chain = [itertools.repeat(silence_chunk, n_prefix_chunks)]
|
||||||
|
|
||||||
|
if audio_prompt is not None:
|
||||||
|
chain.append(torch.split(audio_prompt[:, None, :], mimi.frame_size, dim=-1))
|
||||||
|
# adding a bit (0.8s) of silence to separate prompt and the actual audio
|
||||||
|
chain.append(itertools.repeat(silence_chunk, 10))
|
||||||
|
|
||||||
|
chain += [
|
||||||
|
torch.split(audio[:, None, :], mimi.frame_size, dim=-1),
|
||||||
|
itertools.repeat(silence_chunk, n_suffix_chunks),
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = itertools.chain(*chain)
|
||||||
|
|
||||||
|
text_tokens_accum = []
|
||||||
|
with mimi.streaming(1), lm_gen.streaming(1):
|
||||||
|
for audio_chunk in tqdm.tqdm(chunks):
|
||||||
|
audio_tokens = mimi.encode(audio_chunk)
|
||||||
|
text_tokens = lm_gen.step(audio_tokens)
|
||||||
|
if text_tokens is not None:
|
||||||
|
text_tokens_accum.append(text_tokens)
|
||||||
|
|
||||||
|
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
|
||||||
|
text_tokens = utterance_tokens.cpu().view(-1)
|
||||||
|
|
||||||
|
# if we have an audio prompt and we don't want to have it in the transcript,
|
||||||
|
# we should cut the corresponding number of frames from the output tokens.
|
||||||
|
# However, there is also some amount of padding that happens before it
|
||||||
|
# due to silence_prefix and audio_delay. Normally it is ignored in detokenization,
|
||||||
|
# but now we should account for it to find the position of the prompt transcript.
|
||||||
|
if args.cut_prompt_transcript and audio_prompt is not None:
|
||||||
|
prompt_frames = audio_prompt.shape[1] // mimi.frame_size
|
||||||
|
no_prompt_offset_seconds = audio_delay_seconds + audio_silence_prefix_seconds
|
||||||
|
no_prompt_offset = int(no_prompt_offset_seconds * mimi.frame_rate)
|
||||||
|
text_tokens = text_tokens[prompt_frames + no_prompt_offset :]
|
||||||
|
|
||||||
|
text = tokenizer.decode(
|
||||||
|
text_tokens[text_tokens > padding_token_id].numpy().tolist()
|
||||||
|
)
|
||||||
|
|
||||||
|
print(text)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Example streaming STT w/ a prompt.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--file",
|
||||||
|
required=True,
|
||||||
|
help="File to transcribe.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt_file",
|
||||||
|
required=False,
|
||||||
|
help="Audio of the prompt.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt_text",
|
||||||
|
required=False,
|
||||||
|
help="Text of the prompt.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cut-prompt-transcript",
|
||||||
|
action="store_true",
|
||||||
|
help="Cut the prompt from the output transcript",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-repo", type=str, help="HF repo to load the STT model from. "
|
||||||
|
)
|
||||||
|
parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--moshi-weight", type=str, help="Path to a local checkpoint file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config-path", type=str, help="Path to a local config file.", default=None
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda",
|
||||||
|
help="Device on which to run, defaults to 'cuda'.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
116
scripts/stt_from_mic_mlx.py
Normal file
116
scripts/stt_from_mic_mlx.py
Normal file
|
|
@ -0,0 +1,116 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "huggingface_hub",
|
||||||
|
# "moshi_mlx==0.2.12",
|
||||||
|
# "numpy",
|
||||||
|
# "rustymimi",
|
||||||
|
# "sentencepiece",
|
||||||
|
# "sounddevice",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import queue
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import rustymimi
|
||||||
|
import sentencepiece
|
||||||
|
import sounddevice as sd
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from moshi_mlx import models, utils
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--max-steps", default=4096)
|
||||||
|
parser.add_argument("--hf-repo")
|
||||||
|
parser.add_argument(
|
||||||
|
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.hf_repo is None:
|
||||||
|
if args.vad:
|
||||||
|
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
|
||||||
|
else:
|
||||||
|
args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
|
||||||
|
lm_config = hf_hub_download(args.hf_repo, "config.json")
|
||||||
|
with open(lm_config, "r") as fobj:
|
||||||
|
lm_config = json.load(fobj)
|
||||||
|
mimi_weights = hf_hub_download(args.hf_repo, lm_config["mimi_name"])
|
||||||
|
moshi_name = lm_config.get("moshi_name", "model.safetensors")
|
||||||
|
moshi_weights = hf_hub_download(args.hf_repo, moshi_name)
|
||||||
|
tokenizer = hf_hub_download(args.hf_repo, lm_config["tokenizer_name"])
|
||||||
|
|
||||||
|
lm_config = models.LmConfig.from_config_dict(lm_config)
|
||||||
|
model = models.Lm(lm_config)
|
||||||
|
model.set_dtype(mx.bfloat16)
|
||||||
|
if moshi_weights.endswith(".q4.safetensors"):
|
||||||
|
nn.quantize(model, bits=4, group_size=32)
|
||||||
|
elif moshi_weights.endswith(".q8.safetensors"):
|
||||||
|
nn.quantize(model, bits=8, group_size=64)
|
||||||
|
|
||||||
|
print(f"loading model weights from {moshi_weights}")
|
||||||
|
if args.hf_repo.endswith("-candle"):
|
||||||
|
model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
|
||||||
|
else:
|
||||||
|
model.load_weights(moshi_weights, strict=True)
|
||||||
|
|
||||||
|
print(f"loading the text tokenizer from {tokenizer}")
|
||||||
|
text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer) # type: ignore
|
||||||
|
|
||||||
|
print(f"loading the audio tokenizer {mimi_weights}")
|
||||||
|
generated_codebooks = lm_config.generated_codebooks
|
||||||
|
other_codebooks = lm_config.other_codebooks
|
||||||
|
mimi_codebooks = max(generated_codebooks, other_codebooks)
|
||||||
|
audio_tokenizer = rustymimi.Tokenizer(mimi_weights, num_codebooks=mimi_codebooks) # type: ignore
|
||||||
|
print("warming up the model")
|
||||||
|
model.warmup()
|
||||||
|
gen = models.LmGen(
|
||||||
|
model=model,
|
||||||
|
max_steps=args.max_steps,
|
||||||
|
text_sampler=utils.Sampler(top_k=25, temp=0),
|
||||||
|
audio_sampler=utils.Sampler(top_k=250, temp=0.8),
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
block_queue = queue.Queue()
|
||||||
|
|
||||||
|
def audio_callback(indata, _frames, _time, _status):
|
||||||
|
block_queue.put(indata.copy())
|
||||||
|
|
||||||
|
print("recording audio from microphone, speak to get your words transcribed")
|
||||||
|
last_print_was_vad = False
|
||||||
|
with sd.InputStream(
|
||||||
|
channels=1,
|
||||||
|
dtype="float32",
|
||||||
|
samplerate=24000,
|
||||||
|
blocksize=1920,
|
||||||
|
callback=audio_callback,
|
||||||
|
):
|
||||||
|
while True:
|
||||||
|
block = block_queue.get()
|
||||||
|
block = block[None, :, 0]
|
||||||
|
other_audio_tokens = audio_tokenizer.encode_step(block[None, 0:1])
|
||||||
|
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
|
||||||
|
:, :, :other_codebooks
|
||||||
|
]
|
||||||
|
if args.vad:
|
||||||
|
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
|
||||||
|
if vad_heads:
|
||||||
|
pr_vad = vad_heads[2][0, 0, 0].item()
|
||||||
|
if pr_vad > 0.5 and not last_print_was_vad:
|
||||||
|
print(" [end of turn detected]")
|
||||||
|
last_print_was_vad = True
|
||||||
|
else:
|
||||||
|
text_token = gen.step(other_audio_tokens[0])
|
||||||
|
text_token = text_token[0].item()
|
||||||
|
audio_tokens = gen.last_audio_tokens()
|
||||||
|
_text = None
|
||||||
|
if text_token not in (0, 3):
|
||||||
|
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
|
||||||
|
_text = _text.replace("▁", " ")
|
||||||
|
print(_text, end="", flush=True)
|
||||||
|
last_print_was_vad = False
|
||||||
135
scripts/stt_from_mic_rust_server.py
Normal file
135
scripts/stt_from_mic_rust_server.py
Normal file
|
|
@ -0,0 +1,135 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "msgpack",
|
||||||
|
# "numpy",
|
||||||
|
# "sounddevice",
|
||||||
|
# "websockets",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import signal
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import numpy as np
|
||||||
|
import sounddevice as sd
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
SAMPLE_RATE = 24000
|
||||||
|
|
||||||
|
# The VAD has several prediction heads, each of which tries to determine whether there
|
||||||
|
# has been a pause of a given length. The lengths are 0.5, 1.0, 2.0, and 3.0 seconds.
|
||||||
|
# Lower indices predict pauses more aggressively. In Unmute, we use 2.0 seconds = index 2.
|
||||||
|
PAUSE_PREDICTION_HEAD_INDEX = 2
|
||||||
|
|
||||||
|
|
||||||
|
async def receive_messages(websocket, show_vad: bool = False):
|
||||||
|
"""Receive and process messages from the WebSocket server."""
|
||||||
|
try:
|
||||||
|
speech_started = False
|
||||||
|
async for message in websocket:
|
||||||
|
data = msgpack.unpackb(message, raw=False)
|
||||||
|
|
||||||
|
# The Step message only gets sent if the model has semantic VAD available
|
||||||
|
if data["type"] == "Step" and show_vad:
|
||||||
|
pause_prediction = data["prs"][PAUSE_PREDICTION_HEAD_INDEX]
|
||||||
|
if pause_prediction > 0.5 and speech_started:
|
||||||
|
print("| ", end="", flush=True)
|
||||||
|
speech_started = False
|
||||||
|
|
||||||
|
elif data["type"] == "Word":
|
||||||
|
print(data["text"], end=" ", flush=True)
|
||||||
|
speech_started = True
|
||||||
|
except websockets.ConnectionClosed:
|
||||||
|
print("Connection closed while receiving messages.")
|
||||||
|
|
||||||
|
|
||||||
|
async def send_messages(websocket, audio_queue):
|
||||||
|
"""Send audio data from microphone to WebSocket server."""
|
||||||
|
try:
|
||||||
|
# Start by draining the queue to avoid lags
|
||||||
|
while not audio_queue.empty():
|
||||||
|
await audio_queue.get()
|
||||||
|
|
||||||
|
print("Starting the transcription")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
audio_data = await audio_queue.get()
|
||||||
|
chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data]}
|
||||||
|
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
|
||||||
|
await websocket.send(msg)
|
||||||
|
|
||||||
|
except websockets.ConnectionClosed:
|
||||||
|
print("Connection closed while sending messages.")
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_audio(url: str, api_key: str, show_vad: bool):
|
||||||
|
"""Stream audio data to a WebSocket server."""
|
||||||
|
print("Starting microphone recording...")
|
||||||
|
print("Press Ctrl+C to stop recording")
|
||||||
|
audio_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
def audio_callback(indata, frames, time, status):
|
||||||
|
loop.call_soon_threadsafe(
|
||||||
|
audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start audio stream
|
||||||
|
with sd.InputStream(
|
||||||
|
samplerate=SAMPLE_RATE,
|
||||||
|
channels=1,
|
||||||
|
dtype="float32",
|
||||||
|
callback=audio_callback,
|
||||||
|
blocksize=1920, # 80ms blocks
|
||||||
|
):
|
||||||
|
headers = {"kyutai-api-key": api_key}
|
||||||
|
# Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
|
||||||
|
async with websockets.connect(url, additional_headers=headers) as websocket:
|
||||||
|
send_task = asyncio.create_task(send_messages(websocket, audio_queue))
|
||||||
|
receive_task = asyncio.create_task(
|
||||||
|
receive_messages(websocket, show_vad=show_vad)
|
||||||
|
)
|
||||||
|
await asyncio.gather(send_task, receive_task)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Real-time microphone transcription")
|
||||||
|
parser.add_argument(
|
||||||
|
"--url",
|
||||||
|
help="The URL of the server to which to send the audio",
|
||||||
|
default="ws://127.0.0.1:8080",
|
||||||
|
)
|
||||||
|
parser.add_argument("--api-key", default="public_token")
|
||||||
|
parser.add_argument(
|
||||||
|
"--list-devices", action="store_true", help="List available audio devices"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device", type=int, help="Input device ID (use --list-devices to see options)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--show-vad",
|
||||||
|
action="store_true",
|
||||||
|
help="Visualize the predictions of the semantic voice activity detector with a '|' symbol",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
def handle_sigint(signum, frame):
|
||||||
|
print("Interrupted by user") # Don't complain about KeyboardInterrupt
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, handle_sigint)
|
||||||
|
|
||||||
|
if args.list_devices:
|
||||||
|
print("Available audio devices:")
|
||||||
|
print(sd.query_devices())
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
if args.device is not None:
|
||||||
|
sd.default.device[0] = args.device # Set input device
|
||||||
|
|
||||||
|
url = f"{args.url}/api/asr-streaming"
|
||||||
|
asyncio.run(stream_audio(url, args.api_key, args.show_vad))
|
||||||
206
scripts/tts_mlx.py
Normal file
206
scripts/tts_mlx.py
Normal file
|
|
@ -0,0 +1,206 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "huggingface_hub",
|
||||||
|
# "moshi_mlx==0.2.12",
|
||||||
|
# "numpy",
|
||||||
|
# "sounddevice",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import queue
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
import sentencepiece
|
||||||
|
import sounddevice as sd
|
||||||
|
import sphn
|
||||||
|
from moshi_mlx import models
|
||||||
|
from moshi_mlx.client_utils import make_log
|
||||||
|
from moshi_mlx.models.tts import (
|
||||||
|
DEFAULT_DSM_TTS_REPO,
|
||||||
|
DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
TTSModel,
|
||||||
|
)
|
||||||
|
from moshi_mlx.utils.loaders import hf_get
|
||||||
|
|
||||||
|
|
||||||
|
def log(level: str, msg: str):
|
||||||
|
print(make_log(level, msg))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run Kyutai TTS using the MLX implementation"
|
||||||
|
)
|
||||||
|
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
|
||||||
|
parser.add_argument(
|
||||||
|
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-repo",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_DSM_TTS_REPO,
|
||||||
|
help="HF repo in which to look for the pretrained models.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice-repo",
|
||||||
|
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantize",
|
||||||
|
type=int,
|
||||||
|
help="The quantization to be applied, e.g. 8 for 8 bits.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
mx.random.seed(299792458)
|
||||||
|
|
||||||
|
log("info", "retrieving checkpoints")
|
||||||
|
|
||||||
|
raw_config = hf_get("config.json", args.hf_repo)
|
||||||
|
with open(hf_get(raw_config), "r") as fobj:
|
||||||
|
raw_config = json.load(fobj)
|
||||||
|
|
||||||
|
mimi_weights = hf_get(raw_config["mimi_name"], args.hf_repo)
|
||||||
|
moshi_name = raw_config.get("moshi_name", "model.safetensors")
|
||||||
|
moshi_weights = hf_get(moshi_name, args.hf_repo)
|
||||||
|
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
|
||||||
|
lm_config = models.LmConfig.from_config_dict(raw_config)
|
||||||
|
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
|
||||||
|
# The following line gets around it for now.
|
||||||
|
lm_config.transformer.max_seq_len = lm_config.transformer.context
|
||||||
|
model = models.Lm(lm_config)
|
||||||
|
model.set_dtype(mx.bfloat16)
|
||||||
|
|
||||||
|
log("info", f"loading model weights from {moshi_weights}")
|
||||||
|
model.load_pytorch_weights(str(moshi_weights), lm_config, strict=True)
|
||||||
|
|
||||||
|
if args.quantize is not None:
|
||||||
|
log("info", f"quantizing model to {args.quantize} bits")
|
||||||
|
nn.quantize(model.depformer, bits=args.quantize)
|
||||||
|
for layer in model.transformer.layers:
|
||||||
|
nn.quantize(layer.self_attn, bits=args.quantize)
|
||||||
|
nn.quantize(layer.gating, bits=args.quantize)
|
||||||
|
|
||||||
|
log("info", f"loading the text tokenizer from {tokenizer}")
|
||||||
|
text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer)) # type: ignore
|
||||||
|
|
||||||
|
log("info", f"loading the audio tokenizer {mimi_weights}")
|
||||||
|
generated_codebooks = lm_config.generated_codebooks
|
||||||
|
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(generated_codebooks))
|
||||||
|
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
|
||||||
|
|
||||||
|
cfg_coef_conditioning = None
|
||||||
|
tts_model = TTSModel(
|
||||||
|
model,
|
||||||
|
audio_tokenizer,
|
||||||
|
text_tokenizer,
|
||||||
|
voice_repo=args.voice_repo,
|
||||||
|
temp=0.6,
|
||||||
|
cfg_coef=1,
|
||||||
|
max_padding=8,
|
||||||
|
initial_padding=2,
|
||||||
|
final_padding=2,
|
||||||
|
padding_bonus=0,
|
||||||
|
raw_config=raw_config,
|
||||||
|
)
|
||||||
|
if tts_model.valid_cfg_conditionings:
|
||||||
|
# Model was trained with CFG distillation.
|
||||||
|
cfg_coef_conditioning = tts_model.cfg_coef
|
||||||
|
tts_model.cfg_coef = 1.0
|
||||||
|
cfg_is_no_text = False
|
||||||
|
cfg_is_no_prefix = False
|
||||||
|
else:
|
||||||
|
cfg_is_no_text = True
|
||||||
|
cfg_is_no_prefix = True
|
||||||
|
mimi = tts_model.mimi
|
||||||
|
|
||||||
|
log("info", f"reading input from {args.inp}")
|
||||||
|
if args.inp == "-":
|
||||||
|
if sys.stdin.isatty(): # Interactive
|
||||||
|
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||||
|
text_to_tts = sys.stdin.read().strip()
|
||||||
|
else:
|
||||||
|
with open(args.inp, "r") as fobj:
|
||||||
|
text_to_tts = fobj.read().strip()
|
||||||
|
|
||||||
|
all_entries = [tts_model.prepare_script([text_to_tts])]
|
||||||
|
if tts_model.multi_speaker:
|
||||||
|
voices = [tts_model.get_voice_path(args.voice)]
|
||||||
|
else:
|
||||||
|
voices = []
|
||||||
|
all_attributes = [
|
||||||
|
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
|
||||||
|
]
|
||||||
|
|
||||||
|
wav_frames = queue.Queue()
|
||||||
|
|
||||||
|
def _on_frame(frame):
|
||||||
|
if (frame == -1).any():
|
||||||
|
return
|
||||||
|
_pcm = tts_model.mimi.decode_step(frame[:, :, None])
|
||||||
|
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
|
||||||
|
wav_frames.put_nowait(_pcm)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
log("info", "starting the inference loop")
|
||||||
|
begin = time.time()
|
||||||
|
result = tts_model.generate(
|
||||||
|
all_entries,
|
||||||
|
all_attributes,
|
||||||
|
cfg_is_no_prefix=cfg_is_no_prefix,
|
||||||
|
cfg_is_no_text=cfg_is_no_text,
|
||||||
|
on_frame=_on_frame,
|
||||||
|
)
|
||||||
|
frames = mx.concat(result.frames, axis=-1)
|
||||||
|
total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
|
||||||
|
time_taken = time.time() - begin
|
||||||
|
total_speed = total_duration / time_taken
|
||||||
|
log("info", f"[LM] took {time_taken:.2f}s, total speed {total_speed:.2f}x")
|
||||||
|
return result
|
||||||
|
|
||||||
|
if args.out == "-":
|
||||||
|
|
||||||
|
def audio_callback(outdata, _a, _b, _c):
|
||||||
|
try:
|
||||||
|
pcm_data = wav_frames.get(block=False)
|
||||||
|
outdata[:, 0] = pcm_data
|
||||||
|
except queue.Empty:
|
||||||
|
outdata[:] = 0
|
||||||
|
|
||||||
|
with sd.OutputStream(
|
||||||
|
samplerate=mimi.sample_rate,
|
||||||
|
blocksize=1920,
|
||||||
|
channels=1,
|
||||||
|
callback=audio_callback,
|
||||||
|
):
|
||||||
|
run()
|
||||||
|
time.sleep(3)
|
||||||
|
while True:
|
||||||
|
if wav_frames.qsize() == 0:
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
run()
|
||||||
|
frames = []
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
frames.append(wav_frames.get_nowait())
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
wav = np.concat(frames, -1)
|
||||||
|
sphn.write_wav(args.out, wav, mimi.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
317
scripts/tts_mlx_streaming.py
Normal file
317
scripts/tts_mlx_streaming.py
Normal file
|
|
@ -0,0 +1,317 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "huggingface_hub",
|
||||||
|
# "moshi_mlx==0.2.12",
|
||||||
|
# "numpy",
|
||||||
|
# "sounddevice",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import json
|
||||||
|
import queue
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
import sentencepiece
|
||||||
|
import sounddevice as sd
|
||||||
|
import sphn
|
||||||
|
import typing as tp
|
||||||
|
from moshi_mlx import models
|
||||||
|
from moshi_mlx.models.generate import LmGen
|
||||||
|
from moshi_mlx.client_utils import make_log
|
||||||
|
from moshi_mlx.modules.conditioner import (
|
||||||
|
ConditionAttributes,
|
||||||
|
ConditionTensor,
|
||||||
|
dropout_all_conditions,
|
||||||
|
)
|
||||||
|
from moshi_mlx.utils.sampling import Sampler
|
||||||
|
from moshi_mlx.models.tts import (
|
||||||
|
Entry,
|
||||||
|
DEFAULT_DSM_TTS_REPO,
|
||||||
|
DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
TTSModel,
|
||||||
|
script_to_entries,
|
||||||
|
)
|
||||||
|
from moshi_mlx.utils.loaders import hf_get
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]:
|
||||||
|
multi_speaker = first_turn and model.multi_speaker
|
||||||
|
return script_to_entries(
|
||||||
|
model.tokenizer,
|
||||||
|
model.machine.token_ids,
|
||||||
|
model.mimi.frame_rate,
|
||||||
|
[script],
|
||||||
|
multi_speaker=multi_speaker,
|
||||||
|
padding_between=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_null(
|
||||||
|
all_attributes: tp.Sequence[ConditionAttributes],
|
||||||
|
) -> list[ConditionAttributes]:
|
||||||
|
# When using CFG, returns the null conditions.
|
||||||
|
return dropout_all_conditions(all_attributes)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TTSGen:
|
||||||
|
tts_model: TTSModel
|
||||||
|
attributes: tp.Sequence[ConditionAttributes]
|
||||||
|
on_frame: tp.Optional[tp.Callable[[mx.array], None]] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
tts_model = self.tts_model
|
||||||
|
attributes = self.attributes
|
||||||
|
self.offset = 0
|
||||||
|
self.state = self.tts_model.machine.new_state([])
|
||||||
|
|
||||||
|
if tts_model.cfg_coef != 1.0:
|
||||||
|
if tts_model.valid_cfg_conditionings:
|
||||||
|
raise ValueError(
|
||||||
|
"This model does not support direct CFG, but was trained with "
|
||||||
|
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
|
||||||
|
)
|
||||||
|
nulled = _make_null(attributes)
|
||||||
|
attributes = list(attributes) + nulled
|
||||||
|
|
||||||
|
assert tts_model.lm.condition_provider is not None
|
||||||
|
self.ct = None
|
||||||
|
self.cross_attention_src = None
|
||||||
|
for _attr in attributes:
|
||||||
|
for _key, _value in _attr.text.items():
|
||||||
|
_ct = tts_model.lm.condition_provider.condition_tensor(_key, _value)
|
||||||
|
if self.ct is None:
|
||||||
|
self.ct = _ct
|
||||||
|
else:
|
||||||
|
self.ct = ConditionTensor(self.ct.tensor + _ct.tensor)
|
||||||
|
for _key, _value in _attr.tensor.items():
|
||||||
|
_conditioner = tts_model.lm.condition_provider.conditioners[_key]
|
||||||
|
_ca_src = _conditioner.condition(_value)
|
||||||
|
if self.cross_attention_src is None:
|
||||||
|
self.cross_attention_src = _ca_src
|
||||||
|
else:
|
||||||
|
raise ValueError("multiple cross-attention conditioners")
|
||||||
|
|
||||||
|
def _on_audio_hook(audio_tokens):
|
||||||
|
delays = tts_model.lm.delays
|
||||||
|
for q in range(audio_tokens.shape[0]):
|
||||||
|
delay = delays[q]
|
||||||
|
if self.offset < delay + tts_model.delay_steps:
|
||||||
|
audio_tokens[q] = tts_model.machine.token_ids.zero
|
||||||
|
|
||||||
|
def _on_text_hook(text_tokens):
|
||||||
|
tokens = text_tokens.tolist()
|
||||||
|
out_tokens = []
|
||||||
|
for token in tokens:
|
||||||
|
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
|
||||||
|
out_tokens.append(out_token)
|
||||||
|
text_tokens[:] = mx.array(out_tokens, dtype=mx.int64)
|
||||||
|
|
||||||
|
self.lm_gen = LmGen(
|
||||||
|
tts_model.lm,
|
||||||
|
max_steps=tts_model.max_gen_length,
|
||||||
|
text_sampler=Sampler(temp=tts_model.temp),
|
||||||
|
audio_sampler=Sampler(temp=tts_model.temp),
|
||||||
|
cfg_coef=tts_model.cfg_coef,
|
||||||
|
on_text_hook=_on_text_hook,
|
||||||
|
on_audio_hook=_on_audio_hook,
|
||||||
|
# TODO(laurent):
|
||||||
|
# cfg_is_masked_until=cfg_is_masked_until,
|
||||||
|
# cfg_is_no_text=cfg_is_no_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_last(self):
|
||||||
|
while len(self.state.entries) > 0 or self.state.end_step is not None:
|
||||||
|
self._step()
|
||||||
|
additional_steps = (
|
||||||
|
self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
|
||||||
|
)
|
||||||
|
for _ in range(additional_steps):
|
||||||
|
self._step()
|
||||||
|
|
||||||
|
def process(self):
|
||||||
|
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
|
||||||
|
self._step()
|
||||||
|
|
||||||
|
def _step(self):
|
||||||
|
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
||||||
|
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
||||||
|
input_tokens = (
|
||||||
|
mx.ones((1, missing), dtype=mx.int64)
|
||||||
|
* self.tts_model.machine.token_ids.zero
|
||||||
|
)
|
||||||
|
self.lm_gen.step(
|
||||||
|
input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src
|
||||||
|
)
|
||||||
|
frame = self.lm_gen.last_audio_tokens()
|
||||||
|
self.offset += 1
|
||||||
|
if frame is not None:
|
||||||
|
if self.on_frame is not None:
|
||||||
|
self.on_frame(frame)
|
||||||
|
|
||||||
|
def append_entry(self, entry):
|
||||||
|
self.state.entries.append(entry)
|
||||||
|
|
||||||
|
|
||||||
|
def log(level: str, msg: str):
|
||||||
|
print(make_log(level, msg))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run Kyutai TTS using the MLX implementation"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-repo",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_DSM_TTS_REPO,
|
||||||
|
help="HF repo in which to look for the pretrained models.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice-repo",
|
||||||
|
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantize",
|
||||||
|
type=int,
|
||||||
|
help="The quantization to be applied, e.g. 8 for 8 bits.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
mx.random.seed(299792458)
|
||||||
|
|
||||||
|
log("info", "retrieving checkpoints")
|
||||||
|
|
||||||
|
raw_config = hf_get("config.json", args.hf_repo)
|
||||||
|
with open(hf_get(raw_config), "r") as fobj:
|
||||||
|
raw_config = json.load(fobj)
|
||||||
|
|
||||||
|
mimi_weights = hf_get(raw_config["mimi_name"], args.hf_repo)
|
||||||
|
moshi_name = raw_config.get("moshi_name", "model.safetensors")
|
||||||
|
moshi_weights = hf_get(moshi_name, args.hf_repo)
|
||||||
|
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
|
||||||
|
lm_config = models.LmConfig.from_config_dict(raw_config)
|
||||||
|
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
|
||||||
|
# The following line gets around it for now.
|
||||||
|
lm_config.transformer.max_seq_len = lm_config.transformer.context
|
||||||
|
model = models.Lm(lm_config)
|
||||||
|
model.set_dtype(mx.bfloat16)
|
||||||
|
|
||||||
|
log("info", f"loading model weights from {moshi_weights}")
|
||||||
|
model.load_pytorch_weights(str(moshi_weights), lm_config, strict=True)
|
||||||
|
|
||||||
|
if args.quantize is not None:
|
||||||
|
log("info", f"quantizing model to {args.quantize} bits")
|
||||||
|
nn.quantize(model.depformer, bits=args.quantize)
|
||||||
|
for layer in model.transformer.layers:
|
||||||
|
nn.quantize(layer.self_attn, bits=args.quantize)
|
||||||
|
nn.quantize(layer.gating, bits=args.quantize)
|
||||||
|
|
||||||
|
log("info", f"loading the text tokenizer from {tokenizer}")
|
||||||
|
text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer)) # type: ignore
|
||||||
|
|
||||||
|
log("info", f"loading the audio tokenizer {mimi_weights}")
|
||||||
|
generated_codebooks = lm_config.generated_codebooks
|
||||||
|
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(generated_codebooks))
|
||||||
|
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
|
||||||
|
|
||||||
|
cfg_coef_conditioning = None
|
||||||
|
tts_model = TTSModel(
|
||||||
|
model,
|
||||||
|
audio_tokenizer,
|
||||||
|
text_tokenizer,
|
||||||
|
voice_repo=args.voice_repo,
|
||||||
|
temp=0.6,
|
||||||
|
cfg_coef=1,
|
||||||
|
max_padding=8,
|
||||||
|
initial_padding=2,
|
||||||
|
final_padding=2,
|
||||||
|
padding_bonus=0,
|
||||||
|
raw_config=raw_config,
|
||||||
|
)
|
||||||
|
if tts_model.valid_cfg_conditionings:
|
||||||
|
# Model was trained with CFG distillation.
|
||||||
|
cfg_coef_conditioning = tts_model.cfg_coef
|
||||||
|
tts_model.cfg_coef = 1.0
|
||||||
|
mimi = tts_model.mimi
|
||||||
|
|
||||||
|
log("info", "reading input from stdin")
|
||||||
|
|
||||||
|
if tts_model.multi_speaker:
|
||||||
|
voices = [tts_model.get_voice_path(args.voice)]
|
||||||
|
else:
|
||||||
|
voices = []
|
||||||
|
all_attributes = [
|
||||||
|
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
|
||||||
|
]
|
||||||
|
|
||||||
|
wav_frames = queue.Queue()
|
||||||
|
|
||||||
|
def _on_frame(frame):
|
||||||
|
if (frame == -1).any():
|
||||||
|
return
|
||||||
|
_pcm = tts_model.mimi.decode_step(frame[:, :, None])
|
||||||
|
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
|
||||||
|
wav_frames.put_nowait(_pcm)
|
||||||
|
|
||||||
|
gen = TTSGen(tts_model, all_attributes, on_frame=_on_frame)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
log("info", "starting the inference loop")
|
||||||
|
first_turn = True
|
||||||
|
for line in sys.stdin:
|
||||||
|
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
|
||||||
|
first_turn = False
|
||||||
|
for entry in entries:
|
||||||
|
gen.append_entry(entry)
|
||||||
|
gen.process()
|
||||||
|
gen.process_last()
|
||||||
|
|
||||||
|
if args.out == "-":
|
||||||
|
|
||||||
|
def audio_callback(outdata, _a, _b, _c):
|
||||||
|
try:
|
||||||
|
pcm_data = wav_frames.get(block=False)
|
||||||
|
outdata[:, 0] = pcm_data
|
||||||
|
except queue.Empty:
|
||||||
|
outdata[:] = 0
|
||||||
|
|
||||||
|
with sd.OutputStream(
|
||||||
|
samplerate=mimi.sample_rate,
|
||||||
|
blocksize=1920,
|
||||||
|
channels=1,
|
||||||
|
callback=audio_callback,
|
||||||
|
):
|
||||||
|
run()
|
||||||
|
while True:
|
||||||
|
if wav_frames.qsize() == 0:
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
run()
|
||||||
|
frames = []
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
frames.append(wav_frames.get_nowait())
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
wav = np.concat(frames, -1)
|
||||||
|
sphn.write_wav(args.out, wav, mimi.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
140
scripts/tts_pytorch.py
Normal file
140
scripts/tts_pytorch.py
Normal file
|
|
@ -0,0 +1,140 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "moshi==0.2.11",
|
||||||
|
# "torch",
|
||||||
|
# "sphn",
|
||||||
|
# "sounddevice",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import queue
|
||||||
|
import sphn
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from moshi.models.loaders import CheckpointInfo
|
||||||
|
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run Kyutai TTS using the PyTorch implementation"
|
||||||
|
)
|
||||||
|
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
|
||||||
|
parser.add_argument(
|
||||||
|
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-repo",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_DSM_TTS_REPO,
|
||||||
|
help="HF repo in which to look for the pretrained models.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice-repo",
|
||||||
|
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice",
|
||||||
|
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
|
||||||
|
help="The voice to use, relative to the voice repo root. "
|
||||||
|
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda",
|
||||||
|
help="Device on which to run, defaults to 'cuda'.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("Loading model...")
|
||||||
|
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
|
||||||
|
tts_model = TTSModel.from_checkpoint_info(
|
||||||
|
checkpoint_info, n_q=32, temp=0.6, device=args.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.inp == "-":
|
||||||
|
if sys.stdin.isatty(): # Interactive
|
||||||
|
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||||
|
text = sys.stdin.read().strip()
|
||||||
|
else:
|
||||||
|
with open(args.inp, "r") as fobj:
|
||||||
|
text = fobj.read().strip()
|
||||||
|
|
||||||
|
# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]
|
||||||
|
entries = tts_model.prepare_script([text], padding_between=1)
|
||||||
|
if args.voice.endswith(".safetensors"):
|
||||||
|
voice_path = args.voice
|
||||||
|
else:
|
||||||
|
voice_path = tts_model.get_voice_path(args.voice)
|
||||||
|
# CFG coef goes here because the model was trained with CFG distillation,
|
||||||
|
# so it's not _actually_ doing CFG at inference time.
|
||||||
|
# Also, if you are generating a dialog, you should have two voices in the list.
|
||||||
|
condition_attributes = tts_model.make_condition_attributes(
|
||||||
|
[voice_path], cfg_coef=2.0
|
||||||
|
)
|
||||||
|
_frames_cnt = 0
|
||||||
|
|
||||||
|
if args.out == "-":
|
||||||
|
# Stream the audio to the speakers using sounddevice.
|
||||||
|
import sounddevice as sd
|
||||||
|
|
||||||
|
pcms = queue.Queue()
|
||||||
|
|
||||||
|
def _on_frame(frame):
|
||||||
|
nonlocal _frames_cnt
|
||||||
|
if (frame != -1).all():
|
||||||
|
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
|
||||||
|
pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
|
||||||
|
_frames_cnt += 1
|
||||||
|
print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
|
||||||
|
|
||||||
|
def audio_callback(outdata, _a, _b, _c):
|
||||||
|
try:
|
||||||
|
pcm_data = pcms.get(block=False)
|
||||||
|
outdata[:, 0] = pcm_data
|
||||||
|
except queue.Empty:
|
||||||
|
outdata[:] = 0
|
||||||
|
|
||||||
|
with sd.OutputStream(
|
||||||
|
samplerate=tts_model.mimi.sample_rate,
|
||||||
|
blocksize=1920,
|
||||||
|
channels=1,
|
||||||
|
callback=audio_callback,
|
||||||
|
):
|
||||||
|
with tts_model.mimi.streaming(1):
|
||||||
|
tts_model.generate(
|
||||||
|
[entries], [condition_attributes], on_frame=_on_frame
|
||||||
|
)
|
||||||
|
time.sleep(3)
|
||||||
|
while True:
|
||||||
|
if pcms.qsize() == 0:
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def _on_frame(frame):
|
||||||
|
nonlocal _frames_cnt
|
||||||
|
if (frame != -1).all():
|
||||||
|
_frames_cnt += 1
|
||||||
|
print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
|
||||||
|
|
||||||
|
result = tts_model.generate(
|
||||||
|
[entries], [condition_attributes], on_frame=_on_frame
|
||||||
|
)
|
||||||
|
with tts_model.mimi.streaming(1), torch.no_grad():
|
||||||
|
pcms = []
|
||||||
|
for frame in result.frames[tts_model.delay_steps :]:
|
||||||
|
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
|
||||||
|
pcms.append(np.clip(pcm[0, 0], -1, 1))
|
||||||
|
pcm = np.concatenate(pcms, axis=-1)
|
||||||
|
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
261
scripts/tts_pytorch_streaming.py
Normal file
261
scripts/tts_pytorch_streaming.py
Normal file
|
|
@ -0,0 +1,261 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "moshi==0.2.11",
|
||||||
|
# "torch",
|
||||||
|
# "sphn",
|
||||||
|
# "sounddevice",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import queue
|
||||||
|
import sphn
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import typing as tp
|
||||||
|
from moshi.models.loaders import CheckpointInfo
|
||||||
|
from moshi.conditioners import dropout_all_conditions
|
||||||
|
from moshi.models.lm import LMGen
|
||||||
|
from moshi.models.tts import (
|
||||||
|
Entry,
|
||||||
|
DEFAULT_DSM_TTS_REPO,
|
||||||
|
DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
TTSModel,
|
||||||
|
ConditionAttributes,
|
||||||
|
script_to_entries,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]:
|
||||||
|
multi_speaker = first_turn and model.multi_speaker
|
||||||
|
return script_to_entries(
|
||||||
|
model.tokenizer,
|
||||||
|
model.machine.token_ids,
|
||||||
|
model.mimi.frame_rate,
|
||||||
|
[script],
|
||||||
|
multi_speaker=multi_speaker,
|
||||||
|
padding_between=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_null(
|
||||||
|
all_attributes: tp.Sequence[ConditionAttributes],
|
||||||
|
) -> list[ConditionAttributes]:
|
||||||
|
# When using CFG, returns the null conditions.
|
||||||
|
return dropout_all_conditions(all_attributes)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TTSGen:
|
||||||
|
tts_model: TTSModel
|
||||||
|
attributes: tp.Sequence[ConditionAttributes]
|
||||||
|
on_frame: tp.Optional[tp.Callable[[torch.Tensor], None]] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
tts_model = self.tts_model
|
||||||
|
attributes = self.attributes
|
||||||
|
self.offset = 0
|
||||||
|
self.state = self.tts_model.machine.new_state([])
|
||||||
|
if tts_model.cfg_coef != 1.0:
|
||||||
|
if tts_model.valid_cfg_conditionings:
|
||||||
|
raise ValueError(
|
||||||
|
"This model does not support direct CFG, but was trained with "
|
||||||
|
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
|
||||||
|
)
|
||||||
|
nulled = _make_null(attributes)
|
||||||
|
attributes = list(attributes) + nulled
|
||||||
|
|
||||||
|
assert tts_model.lm.condition_provider is not None
|
||||||
|
prepared = tts_model.lm.condition_provider.prepare(attributes)
|
||||||
|
condition_tensors = tts_model.lm.condition_provider(prepared)
|
||||||
|
|
||||||
|
def _on_text_logits_hook(text_logits):
|
||||||
|
if tts_model.padding_bonus:
|
||||||
|
text_logits[..., tts_model.machine.token_ids.pad] += (
|
||||||
|
tts_model.padding_bonus
|
||||||
|
)
|
||||||
|
return text_logits
|
||||||
|
|
||||||
|
def _on_audio_hook(audio_tokens):
|
||||||
|
audio_offset = tts_model.lm.audio_offset
|
||||||
|
delays = tts_model.lm.delays
|
||||||
|
for q in range(audio_tokens.shape[1]):
|
||||||
|
delay = delays[q + audio_offset]
|
||||||
|
if self.offset < delay + tts_model.delay_steps:
|
||||||
|
audio_tokens[:, q] = tts_model.machine.token_ids.zero
|
||||||
|
|
||||||
|
def _on_text_hook(text_tokens):
|
||||||
|
tokens = text_tokens.tolist()
|
||||||
|
out_tokens = []
|
||||||
|
for token in tokens:
|
||||||
|
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
|
||||||
|
out_tokens.append(out_token)
|
||||||
|
text_tokens[:] = torch.tensor(
|
||||||
|
out_tokens, dtype=torch.long, device=text_tokens.device
|
||||||
|
)
|
||||||
|
|
||||||
|
tts_model.lm.dep_q = tts_model.n_q
|
||||||
|
self.lm_gen = LMGen(
|
||||||
|
tts_model.lm,
|
||||||
|
temp=tts_model.temp,
|
||||||
|
temp_text=tts_model.temp,
|
||||||
|
cfg_coef=tts_model.cfg_coef,
|
||||||
|
condition_tensors=condition_tensors,
|
||||||
|
on_text_logits_hook=_on_text_logits_hook,
|
||||||
|
on_text_hook=_on_text_hook,
|
||||||
|
on_audio_hook=_on_audio_hook,
|
||||||
|
cfg_is_masked_until=None,
|
||||||
|
cfg_is_no_text=True,
|
||||||
|
)
|
||||||
|
self.lm_gen.streaming_forever(1)
|
||||||
|
|
||||||
|
def process_last(self):
|
||||||
|
while len(self.state.entries) > 0 or self.state.end_step is not None:
|
||||||
|
self._step()
|
||||||
|
additional_steps = (
|
||||||
|
self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
|
||||||
|
)
|
||||||
|
for _ in range(additional_steps):
|
||||||
|
self._step()
|
||||||
|
|
||||||
|
def process(self):
|
||||||
|
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
|
||||||
|
self._step()
|
||||||
|
|
||||||
|
def _step(self):
|
||||||
|
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
|
||||||
|
input_tokens = torch.full(
|
||||||
|
(1, missing, 1),
|
||||||
|
self.tts_model.machine.token_ids.zero,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.tts_model.lm.device,
|
||||||
|
)
|
||||||
|
frame = self.lm_gen.step(input_tokens)
|
||||||
|
self.offset += 1
|
||||||
|
if frame is not None:
|
||||||
|
if self.on_frame is not None:
|
||||||
|
self.on_frame(frame)
|
||||||
|
|
||||||
|
def append_entry(self, entry):
|
||||||
|
self.state.entries.append(entry)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run Kyutai TTS using the PyTorch implementation"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-repo",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_DSM_TTS_REPO,
|
||||||
|
help="HF repo in which to look for the pretrained models.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice-repo",
|
||||||
|
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice",
|
||||||
|
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
|
||||||
|
help="The voice to use, relative to the voice repo root. "
|
||||||
|
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda",
|
||||||
|
help="Device on which to run, defaults to 'cuda'.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("Loading model...")
|
||||||
|
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
|
||||||
|
tts_model = TTSModel.from_checkpoint_info(
|
||||||
|
checkpoint_info, n_q=32, temp=0.6, device=args.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.voice.endswith(".safetensors"):
|
||||||
|
voice_path = args.voice
|
||||||
|
else:
|
||||||
|
voice_path = tts_model.get_voice_path(args.voice)
|
||||||
|
# CFG coef goes here because the model was trained with CFG distillation,
|
||||||
|
# so it's not _actually_ doing CFG at inference time.
|
||||||
|
# Also, if you are generating a dialog, you should have two voices in the list.
|
||||||
|
condition_attributes = tts_model.make_condition_attributes(
|
||||||
|
[voice_path], cfg_coef=2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if sys.stdin.isatty(): # Interactive
|
||||||
|
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||||
|
|
||||||
|
if args.out == "-":
|
||||||
|
# Stream the audio to the speakers using sounddevice.
|
||||||
|
import sounddevice as sd
|
||||||
|
|
||||||
|
pcms = queue.Queue()
|
||||||
|
|
||||||
|
def _on_frame(frame):
|
||||||
|
if (frame != -1).all():
|
||||||
|
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
|
||||||
|
pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
|
||||||
|
|
||||||
|
def audio_callback(outdata, _a, _b, _c):
|
||||||
|
try:
|
||||||
|
pcm_data = pcms.get(block=False)
|
||||||
|
outdata[:, 0] = pcm_data
|
||||||
|
except queue.Empty:
|
||||||
|
outdata[:] = 0
|
||||||
|
|
||||||
|
gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame)
|
||||||
|
|
||||||
|
with sd.OutputStream(
|
||||||
|
samplerate=tts_model.mimi.sample_rate,
|
||||||
|
blocksize=1920,
|
||||||
|
channels=1,
|
||||||
|
callback=audio_callback,
|
||||||
|
) and tts_model.mimi.streaming(1):
|
||||||
|
first_turn = True
|
||||||
|
for line in sys.stdin:
|
||||||
|
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
|
||||||
|
first_turn = False
|
||||||
|
for entry in entries:
|
||||||
|
gen.append_entry(entry)
|
||||||
|
gen.process()
|
||||||
|
gen.process_last()
|
||||||
|
while True:
|
||||||
|
if pcms.qsize() == 0:
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
pcms = []
|
||||||
|
|
||||||
|
def _on_frame(frame: torch.Tensor):
|
||||||
|
if (frame != -1).all():
|
||||||
|
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
|
||||||
|
pcms.append(np.clip(pcm[0, 0]))
|
||||||
|
|
||||||
|
gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame)
|
||||||
|
with tts_model.mimi.streaming(1):
|
||||||
|
first_turn = True
|
||||||
|
for line in sys.stdin:
|
||||||
|
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
|
||||||
|
first_turn = False
|
||||||
|
for entry in entries:
|
||||||
|
gen.append_entry(entry)
|
||||||
|
gen.process()
|
||||||
|
gen.process_last()
|
||||||
|
pcm = np.concatenate(pcms, axis=-1)
|
||||||
|
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
178
scripts/tts_rust_server.py
Normal file
178
scripts/tts_rust_server.py
Normal file
|
|
@ -0,0 +1,178 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "msgpack",
|
||||||
|
# "numpy",
|
||||||
|
# "sphn",
|
||||||
|
# "websockets",
|
||||||
|
# "sounddevice",
|
||||||
|
# "tqdm",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import numpy as np
|
||||||
|
import sounddevice as sd
|
||||||
|
import sphn
|
||||||
|
import tqdm
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
SAMPLE_RATE = 24000
|
||||||
|
|
||||||
|
TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."
|
||||||
|
DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
|
||||||
|
AUTH_TOKEN = "public_token"
|
||||||
|
|
||||||
|
|
||||||
|
async def receive_messages(websocket: websockets.ClientConnection, output_queue):
|
||||||
|
with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar:
|
||||||
|
accumulated_samples = 0
|
||||||
|
last_seconds = 0
|
||||||
|
|
||||||
|
async for message_bytes in websocket:
|
||||||
|
msg = msgpack.unpackb(message_bytes)
|
||||||
|
|
||||||
|
if msg["type"] == "Audio":
|
||||||
|
pcm = np.array(msg["pcm"]).astype(np.float32)
|
||||||
|
await output_queue.put(pcm)
|
||||||
|
|
||||||
|
accumulated_samples += len(msg["pcm"])
|
||||||
|
current_seconds = accumulated_samples // SAMPLE_RATE
|
||||||
|
if current_seconds > last_seconds:
|
||||||
|
pbar.update(current_seconds - last_seconds)
|
||||||
|
last_seconds = current_seconds
|
||||||
|
|
||||||
|
print("End of audio.")
|
||||||
|
await output_queue.put(None) # Signal end of audio
|
||||||
|
|
||||||
|
|
||||||
|
async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
|
||||||
|
if out == "-":
|
||||||
|
should_exit = False
|
||||||
|
|
||||||
|
def audio_callback(outdata, _a, _b, _c):
|
||||||
|
nonlocal should_exit
|
||||||
|
|
||||||
|
try:
|
||||||
|
pcm_data = output_queue.get_nowait()
|
||||||
|
if pcm_data is not None:
|
||||||
|
outdata[:, 0] = pcm_data
|
||||||
|
else:
|
||||||
|
should_exit = True
|
||||||
|
outdata[:] = 0
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
outdata[:] = 0
|
||||||
|
|
||||||
|
with sd.OutputStream(
|
||||||
|
samplerate=SAMPLE_RATE,
|
||||||
|
blocksize=1920,
|
||||||
|
channels=1,
|
||||||
|
callback=audio_callback,
|
||||||
|
):
|
||||||
|
while True:
|
||||||
|
if should_exit:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
else:
|
||||||
|
frames = []
|
||||||
|
while True:
|
||||||
|
item = await output_queue.get()
|
||||||
|
if item is None:
|
||||||
|
break
|
||||||
|
frames.append(item)
|
||||||
|
|
||||||
|
sphn.write_wav(out, np.concat(frames, -1), SAMPLE_RATE)
|
||||||
|
print(f"Saved audio to {out}")
|
||||||
|
|
||||||
|
|
||||||
|
async def read_lines_from_stdin():
|
||||||
|
reader = asyncio.StreamReader()
|
||||||
|
protocol = asyncio.StreamReaderProtocol(reader)
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
|
||||||
|
while True:
|
||||||
|
line = await reader.readline()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
yield line.decode().rstrip()
|
||||||
|
|
||||||
|
|
||||||
|
async def read_lines_from_file(path: str):
|
||||||
|
queue = asyncio.Queue()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
def producer():
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
asyncio.run_coroutine_threadsafe(queue.put(line), loop)
|
||||||
|
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
|
||||||
|
|
||||||
|
await asyncio.to_thread(producer)
|
||||||
|
while True:
|
||||||
|
line = await queue.get()
|
||||||
|
if line is None:
|
||||||
|
break
|
||||||
|
yield line
|
||||||
|
|
||||||
|
|
||||||
|
async def get_lines(source: str):
|
||||||
|
if source == "-":
|
||||||
|
async for line in read_lines_from_stdin():
|
||||||
|
yield line
|
||||||
|
else:
|
||||||
|
async for line in read_lines_from_file(source):
|
||||||
|
yield line
|
||||||
|
|
||||||
|
|
||||||
|
async def websocket_client():
|
||||||
|
parser = argparse.ArgumentParser(description="Use the TTS streaming API")
|
||||||
|
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
|
||||||
|
parser.add_argument(
|
||||||
|
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice",
|
||||||
|
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
|
||||||
|
help="The voice to use, relative to the voice repo root. "
|
||||||
|
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--url",
|
||||||
|
help="The URL of the server to which to send the audio",
|
||||||
|
default="ws://127.0.0.1:8080",
|
||||||
|
)
|
||||||
|
parser.add_argument("--api-key", default="public_token")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = {"voice": args.voice, "format": "PcmMessagePack"}
|
||||||
|
uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
|
||||||
|
print(uri)
|
||||||
|
|
||||||
|
if args.inp == "-":
|
||||||
|
if sys.stdin.isatty(): # Interactive
|
||||||
|
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||||
|
headers = {"kyutai-api-key": args.api_key}
|
||||||
|
|
||||||
|
async with websockets.connect(uri, additional_headers=headers) as websocket:
|
||||||
|
print("connected")
|
||||||
|
|
||||||
|
async def send_loop():
|
||||||
|
print("go send")
|
||||||
|
async for line in get_lines(args.inp):
|
||||||
|
for word in line.split():
|
||||||
|
await websocket.send(msgpack.packb({"type": "Text", "text": word}))
|
||||||
|
await websocket.send(msgpack.packb({"type": "Eos"}))
|
||||||
|
|
||||||
|
output_queue = asyncio.Queue()
|
||||||
|
receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
|
||||||
|
output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
|
||||||
|
send_task = asyncio.create_task(send_loop())
|
||||||
|
await asyncio.gather(receive_task, output_audio_task, send_task)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(websocket_client())
|
||||||
633
stt-rs/Cargo.lock
generated
633
stt-rs/Cargo.lock
generated
|
|
@ -97,6 +97,12 @@ version = "0.7.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "atomic-waker"
|
||||||
|
version = "1.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "audiopus_sys"
|
name = "audiopus_sys"
|
||||||
version = "0.2.2"
|
version = "0.2.2"
|
||||||
|
|
@ -233,7 +239,7 @@ dependencies = [
|
||||||
"metal 0.27.0",
|
"metal 0.27.0",
|
||||||
"num-traits",
|
"num-traits",
|
||||||
"num_cpus",
|
"num_cpus",
|
||||||
"rand 0.9.1",
|
"rand",
|
||||||
"rand_distr",
|
"rand_distr",
|
||||||
"rayon",
|
"rayon",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
|
|
@ -295,7 +301,7 @@ dependencies = [
|
||||||
"candle-nn",
|
"candle-nn",
|
||||||
"fancy-regex",
|
"fancy-regex",
|
||||||
"num-traits",
|
"num-traits",
|
||||||
"rand 0.9.1",
|
"rand",
|
||||||
"rayon",
|
"rayon",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
@ -476,23 +482,23 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dirs"
|
name = "dirs"
|
||||||
version = "5.0.1"
|
version = "6.0.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
|
checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"dirs-sys",
|
"dirs-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dirs-sys"
|
name = "dirs-sys"
|
||||||
version = "0.4.1"
|
version = "0.5.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c"
|
checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"option-ext",
|
"option-ext",
|
||||||
"redox_users",
|
"redox_users",
|
||||||
"windows-sys 0.48.0",
|
"windows-sys 0.60.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -607,6 +613,12 @@ dependencies = [
|
||||||
"miniz_oxide",
|
"miniz_oxide",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fnv"
|
||||||
|
version = "1.0.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "foreign-types"
|
name = "foreign-types"
|
||||||
version = "0.3.2"
|
version = "0.3.2"
|
||||||
|
|
@ -658,12 +670,48 @@ dependencies = [
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures"
|
||||||
|
version = "0.3.31"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
|
||||||
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
|
"futures-core",
|
||||||
|
"futures-executor",
|
||||||
|
"futures-io",
|
||||||
|
"futures-sink",
|
||||||
|
"futures-task",
|
||||||
|
"futures-util",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-channel"
|
||||||
|
version = "0.3.31"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
|
||||||
|
dependencies = [
|
||||||
|
"futures-core",
|
||||||
|
"futures-sink",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-core"
|
name = "futures-core"
|
||||||
version = "0.3.31"
|
version = "0.3.31"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
|
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-executor"
|
||||||
|
version = "0.3.31"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
|
||||||
|
dependencies = [
|
||||||
|
"futures-core",
|
||||||
|
"futures-task",
|
||||||
|
"futures-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-io"
|
name = "futures-io"
|
||||||
version = "0.3.31"
|
version = "0.3.31"
|
||||||
|
|
@ -699,9 +747,13 @@ version = "0.3.31"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
|
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
|
"futures-io",
|
||||||
"futures-macro",
|
"futures-macro",
|
||||||
|
"futures-sink",
|
||||||
"futures-task",
|
"futures-task",
|
||||||
|
"memchr",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"pin-utils",
|
"pin-utils",
|
||||||
"slab",
|
"slab",
|
||||||
|
|
@ -979,6 +1031,25 @@ version = "0.3.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
|
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "h2"
|
||||||
|
version = "0.4.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a9421a676d1b147b16b82c9225157dc629087ef8ec4d5e2960f9437a90dac0a5"
|
||||||
|
dependencies = [
|
||||||
|
"atomic-waker",
|
||||||
|
"bytes",
|
||||||
|
"fnv",
|
||||||
|
"futures-core",
|
||||||
|
"futures-sink",
|
||||||
|
"http",
|
||||||
|
"indexmap",
|
||||||
|
"slab",
|
||||||
|
"tokio",
|
||||||
|
"tokio-util 0.7.15",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "half"
|
name = "half"
|
||||||
version = "2.6.0"
|
version = "2.6.0"
|
||||||
|
|
@ -989,7 +1060,7 @@ dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"crunchy",
|
"crunchy",
|
||||||
"num-traits",
|
"num-traits",
|
||||||
"rand 0.9.1",
|
"rand",
|
||||||
"rand_distr",
|
"rand_distr",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -1013,19 +1084,144 @@ checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hf-hub"
|
name = "hf-hub"
|
||||||
version = "0.3.2"
|
version = "0.4.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
|
checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"dirs",
|
"dirs",
|
||||||
|
"futures",
|
||||||
|
"http",
|
||||||
"indicatif",
|
"indicatif",
|
||||||
|
"libc",
|
||||||
"log",
|
"log",
|
||||||
"native-tls",
|
"native-tls",
|
||||||
"rand 0.8.5",
|
"num_cpus",
|
||||||
|
"rand",
|
||||||
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"thiserror 1.0.69",
|
"thiserror 2.0.12",
|
||||||
|
"tokio",
|
||||||
"ureq",
|
"ureq",
|
||||||
|
"windows-sys 0.60.2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "http"
|
||||||
|
version = "1.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"fnv",
|
||||||
|
"itoa",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "http-body"
|
||||||
|
version = "1.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"http",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "http-body-util"
|
||||||
|
version = "0.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"futures-core",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"pin-project-lite",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "httparse"
|
||||||
|
version = "1.10.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hyper"
|
||||||
|
version = "1.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"futures-channel",
|
||||||
|
"futures-util",
|
||||||
|
"h2",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"httparse",
|
||||||
|
"itoa",
|
||||||
|
"pin-project-lite",
|
||||||
|
"smallvec",
|
||||||
|
"tokio",
|
||||||
|
"want",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hyper-rustls"
|
||||||
|
version = "0.27.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
|
||||||
|
dependencies = [
|
||||||
|
"http",
|
||||||
|
"hyper",
|
||||||
|
"hyper-util",
|
||||||
|
"rustls",
|
||||||
|
"rustls-pki-types",
|
||||||
|
"tokio",
|
||||||
|
"tokio-rustls",
|
||||||
|
"tower-service",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hyper-tls"
|
||||||
|
version = "0.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"http-body-util",
|
||||||
|
"hyper",
|
||||||
|
"hyper-util",
|
||||||
|
"native-tls",
|
||||||
|
"tokio",
|
||||||
|
"tokio-native-tls",
|
||||||
|
"tower-service",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hyper-util"
|
||||||
|
version = "0.1.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb"
|
||||||
|
dependencies = [
|
||||||
|
"base64",
|
||||||
|
"bytes",
|
||||||
|
"futures-channel",
|
||||||
|
"futures-core",
|
||||||
|
"futures-util",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"hyper",
|
||||||
|
"ipnet",
|
||||||
|
"libc",
|
||||||
|
"percent-encoding",
|
||||||
|
"pin-project-lite",
|
||||||
|
"socket2",
|
||||||
|
"system-configuration",
|
||||||
|
"tokio",
|
||||||
|
"tower-service",
|
||||||
|
"tracing",
|
||||||
|
"windows-registry",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1158,6 +1354,22 @@ dependencies = [
|
||||||
"web-time",
|
"web-time",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ipnet"
|
||||||
|
version = "2.11.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iri-string"
|
||||||
|
version = "0.7.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "is_terminal_polyfill"
|
name = "is_terminal_polyfill"
|
||||||
version = "1.70.1"
|
version = "1.70.1"
|
||||||
|
|
@ -1345,6 +1557,12 @@ dependencies = [
|
||||||
"paste",
|
"paste",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mime"
|
||||||
|
version = "0.3.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "miniz_oxide"
|
name = "miniz_oxide"
|
||||||
version = "0.8.9"
|
version = "0.8.9"
|
||||||
|
|
@ -1559,7 +1777,7 @@ dependencies = [
|
||||||
"futures-io",
|
"futures-io",
|
||||||
"pin-project",
|
"pin-project",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-util",
|
"tokio-util 0.6.10",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1822,35 +2040,14 @@ version = "5.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "rand"
|
|
||||||
version = "0.8.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
|
||||||
dependencies = [
|
|
||||||
"libc",
|
|
||||||
"rand_chacha 0.3.1",
|
|
||||||
"rand_core 0.6.4",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rand"
|
name = "rand"
|
||||||
version = "0.9.1"
|
version = "0.9.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
|
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"rand_chacha 0.9.0",
|
"rand_chacha",
|
||||||
"rand_core 0.9.3",
|
"rand_core",
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "rand_chacha"
|
|
||||||
version = "0.3.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
|
||||||
dependencies = [
|
|
||||||
"ppv-lite86",
|
|
||||||
"rand_core 0.6.4",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1860,16 +2057,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
|
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ppv-lite86",
|
"ppv-lite86",
|
||||||
"rand_core 0.9.3",
|
"rand_core",
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "rand_core"
|
|
||||||
version = "0.6.4"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
|
||||||
dependencies = [
|
|
||||||
"getrandom 0.2.16",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1888,7 +2076,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
|
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"num-traits",
|
"num-traits",
|
||||||
"rand 0.9.1",
|
"rand",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1955,13 +2143,13 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_users"
|
name = "redox_users"
|
||||||
version = "0.4.6"
|
version = "0.5.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
|
checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"getrandom 0.2.16",
|
"getrandom 0.2.16",
|
||||||
"libredox",
|
"libredox",
|
||||||
"thiserror 1.0.69",
|
"thiserror 2.0.12",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1993,6 +2181,49 @@ version = "0.8.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "reqwest"
|
||||||
|
version = "0.12.20"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "eabf4c97d9130e2bf606614eb937e86edac8292eaa6f422f995d7e8de1eb1813"
|
||||||
|
dependencies = [
|
||||||
|
"base64",
|
||||||
|
"bytes",
|
||||||
|
"encoding_rs",
|
||||||
|
"futures-core",
|
||||||
|
"futures-util",
|
||||||
|
"h2",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"http-body-util",
|
||||||
|
"hyper",
|
||||||
|
"hyper-rustls",
|
||||||
|
"hyper-tls",
|
||||||
|
"hyper-util",
|
||||||
|
"js-sys",
|
||||||
|
"log",
|
||||||
|
"mime",
|
||||||
|
"native-tls",
|
||||||
|
"percent-encoding",
|
||||||
|
"pin-project-lite",
|
||||||
|
"rustls-pki-types",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"serde_urlencoded",
|
||||||
|
"sync_wrapper",
|
||||||
|
"tokio",
|
||||||
|
"tokio-native-tls",
|
||||||
|
"tokio-util 0.7.15",
|
||||||
|
"tower",
|
||||||
|
"tower-http",
|
||||||
|
"tower-service",
|
||||||
|
"url",
|
||||||
|
"wasm-bindgen",
|
||||||
|
"wasm-bindgen-futures",
|
||||||
|
"wasm-streams",
|
||||||
|
"web-sys",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ring"
|
name = "ring"
|
||||||
version = "0.17.14"
|
version = "0.17.14"
|
||||||
|
|
@ -2087,6 +2318,12 @@ dependencies = [
|
||||||
"untrusted",
|
"untrusted",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustversion"
|
||||||
|
version = "1.0.21"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ryu"
|
name = "ryu"
|
||||||
version = "1.0.20"
|
version = "1.0.20"
|
||||||
|
|
@ -2223,6 +2460,18 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_urlencoded"
|
||||||
|
version = "0.7.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd"
|
||||||
|
dependencies = [
|
||||||
|
"form_urlencoded",
|
||||||
|
"itoa",
|
||||||
|
"ryu",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "shlex"
|
name = "shlex"
|
||||||
version = "1.3.0"
|
version = "1.3.0"
|
||||||
|
|
@ -2260,6 +2509,17 @@ dependencies = [
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.52.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "socks"
|
||||||
|
version = "0.3.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b"
|
||||||
|
dependencies = [
|
||||||
|
"byteorder",
|
||||||
|
"libc",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "stable_deref_trait"
|
name = "stable_deref_trait"
|
||||||
version = "1.2.0"
|
version = "1.2.0"
|
||||||
|
|
@ -2501,6 +2761,15 @@ dependencies = [
|
||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sync_wrapper"
|
||||||
|
version = "1.0.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
|
||||||
|
dependencies = [
|
||||||
|
"futures-core",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "synstructure"
|
name = "synstructure"
|
||||||
version = "0.13.2"
|
version = "0.13.2"
|
||||||
|
|
@ -2540,6 +2809,27 @@ dependencies = [
|
||||||
"walkdir",
|
"walkdir",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "system-configuration"
|
||||||
|
version = "0.6.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.9.1",
|
||||||
|
"core-foundation",
|
||||||
|
"system-configuration-sys",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "system-configuration-sys"
|
||||||
|
version = "0.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
|
||||||
|
dependencies = [
|
||||||
|
"core-foundation-sys",
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tempfile"
|
name = "tempfile"
|
||||||
version = "3.20.0"
|
version = "3.20.0"
|
||||||
|
|
@ -2632,6 +2922,26 @@ dependencies = [
|
||||||
"syn 2.0.103",
|
"syn 2.0.103",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-native-tls"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
|
||||||
|
dependencies = [
|
||||||
|
"native-tls",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-rustls"
|
||||||
|
version = "0.26.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b"
|
||||||
|
dependencies = [
|
||||||
|
"rustls",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-util"
|
name = "tokio-util"
|
||||||
version = "0.6.10"
|
version = "0.6.10"
|
||||||
|
|
@ -2647,6 +2957,19 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-util"
|
||||||
|
version = "0.7.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"futures-core",
|
||||||
|
"futures-sink",
|
||||||
|
"pin-project-lite",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "toml_datetime"
|
name = "toml_datetime"
|
||||||
version = "0.6.11"
|
version = "0.6.11"
|
||||||
|
|
@ -2664,6 +2987,51 @@ dependencies = [
|
||||||
"winnow",
|
"winnow",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tower"
|
||||||
|
version = "0.5.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9"
|
||||||
|
dependencies = [
|
||||||
|
"futures-core",
|
||||||
|
"futures-util",
|
||||||
|
"pin-project-lite",
|
||||||
|
"sync_wrapper",
|
||||||
|
"tokio",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tower-http"
|
||||||
|
version = "0.6.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.9.1",
|
||||||
|
"bytes",
|
||||||
|
"futures-util",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"iri-string",
|
||||||
|
"pin-project-lite",
|
||||||
|
"tower",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tower-layer"
|
||||||
|
version = "0.3.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tower-service"
|
||||||
|
version = "0.3.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tracing"
|
name = "tracing"
|
||||||
version = "0.1.41"
|
version = "0.1.41"
|
||||||
|
|
@ -2705,6 +3073,12 @@ dependencies = [
|
||||||
"strength_reduce",
|
"strength_reduce",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "try-lock"
|
||||||
|
version = "0.2.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ug"
|
name = "ug"
|
||||||
version = "0.4.0"
|
version = "0.4.0"
|
||||||
|
|
@ -2786,6 +3160,7 @@ dependencies = [
|
||||||
"rustls-pki-types",
|
"rustls-pki-types",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"socks",
|
||||||
"url",
|
"url",
|
||||||
"webpki-roots 0.26.11",
|
"webpki-roots 0.26.11",
|
||||||
]
|
]
|
||||||
|
|
@ -2835,6 +3210,15 @@ dependencies = [
|
||||||
"winapi-util",
|
"winapi-util",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "want"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e"
|
||||||
|
dependencies = [
|
||||||
|
"try-lock",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wasi"
|
name = "wasi"
|
||||||
version = "0.11.1+wasi-snapshot-preview1"
|
version = "0.11.1+wasi-snapshot-preview1"
|
||||||
|
|
@ -2858,6 +3242,7 @@ checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
|
"rustversion",
|
||||||
"wasm-bindgen-macro",
|
"wasm-bindgen-macro",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -2875,6 +3260,19 @@ dependencies = [
|
||||||
"wasm-bindgen-shared",
|
"wasm-bindgen-shared",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasm-bindgen-futures"
|
||||||
|
version = "0.4.50"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"js-sys",
|
||||||
|
"once_cell",
|
||||||
|
"wasm-bindgen",
|
||||||
|
"web-sys",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wasm-bindgen-macro"
|
name = "wasm-bindgen-macro"
|
||||||
version = "0.2.100"
|
version = "0.2.100"
|
||||||
|
|
@ -2907,6 +3305,29 @@ dependencies = [
|
||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasm-streams"
|
||||||
|
version = "0.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
|
||||||
|
dependencies = [
|
||||||
|
"futures-util",
|
||||||
|
"js-sys",
|
||||||
|
"wasm-bindgen",
|
||||||
|
"wasm-bindgen-futures",
|
||||||
|
"web-sys",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "web-sys"
|
||||||
|
version = "0.3.77"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
|
||||||
|
dependencies = [
|
||||||
|
"js-sys",
|
||||||
|
"wasm-bindgen",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "web-time"
|
name = "web-time"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
|
|
@ -2935,6 +3356,22 @@ dependencies = [
|
||||||
"rustls-pki-types",
|
"rustls-pki-types",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi"
|
||||||
|
version = "0.3.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||||
|
dependencies = [
|
||||||
|
"winapi-i686-pc-windows-gnu",
|
||||||
|
"winapi-x86_64-pc-windows-gnu",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-i686-pc-windows-gnu"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "winapi-util"
|
name = "winapi-util"
|
||||||
version = "0.1.9"
|
version = "0.1.9"
|
||||||
|
|
@ -2945,12 +3382,44 @@ dependencies = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows-sys"
|
name = "winapi-x86_64-pc-windows-gnu"
|
||||||
version = "0.48.0"
|
version = "0.4.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
|
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-link"
|
||||||
|
version = "0.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-registry"
|
||||||
|
version = "0.5.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5b8a9ed28765efc97bbc954883f4e6796c33a06546ebafacbabee9696967499e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows-targets 0.48.5",
|
"windows-link",
|
||||||
|
"windows-result",
|
||||||
|
"windows-strings",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-result"
|
||||||
|
version = "0.3.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6"
|
||||||
|
dependencies = [
|
||||||
|
"windows-link",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-strings"
|
||||||
|
version = "0.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57"
|
||||||
|
dependencies = [
|
||||||
|
"windows-link",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -2972,18 +3441,12 @@ dependencies = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows-targets"
|
name = "windows-sys"
|
||||||
version = "0.48.5"
|
version = "0.60.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
|
checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows_aarch64_gnullvm 0.48.5",
|
"windows-targets 0.53.2",
|
||||||
"windows_aarch64_msvc 0.48.5",
|
|
||||||
"windows_i686_gnu 0.48.5",
|
|
||||||
"windows_i686_msvc 0.48.5",
|
|
||||||
"windows_x86_64_gnu 0.48.5",
|
|
||||||
"windows_x86_64_gnullvm 0.48.5",
|
|
||||||
"windows_x86_64_msvc 0.48.5",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -3018,12 +3481,6 @@ dependencies = [
|
||||||
"windows_x86_64_msvc 0.53.0",
|
"windows_x86_64_msvc 0.53.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_aarch64_gnullvm"
|
|
||||||
version = "0.48.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_aarch64_gnullvm"
|
name = "windows_aarch64_gnullvm"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
|
|
@ -3036,12 +3493,6 @@ version = "0.53.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
|
checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_aarch64_msvc"
|
|
||||||
version = "0.48.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_aarch64_msvc"
|
name = "windows_aarch64_msvc"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
|
|
@ -3054,12 +3505,6 @@ version = "0.53.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
|
checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_i686_gnu"
|
|
||||||
version = "0.48.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_i686_gnu"
|
name = "windows_i686_gnu"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
|
|
@ -3084,12 +3529,6 @@ version = "0.53.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
|
checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_i686_msvc"
|
|
||||||
version = "0.48.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_i686_msvc"
|
name = "windows_i686_msvc"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
|
|
@ -3102,12 +3541,6 @@ version = "0.53.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
|
checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_x86_64_gnu"
|
|
||||||
version = "0.48.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_x86_64_gnu"
|
name = "windows_x86_64_gnu"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
|
|
@ -3120,12 +3553,6 @@ version = "0.53.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
|
checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_x86_64_gnullvm"
|
|
||||||
version = "0.48.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_x86_64_gnullvm"
|
name = "windows_x86_64_gnullvm"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
|
|
@ -3138,12 +3565,6 @@ version = "0.53.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
|
checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "windows_x86_64_msvc"
|
|
||||||
version = "0.48.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_x86_64_msvc"
|
name = "windows_x86_64_msvc"
|
||||||
version = "0.52.6"
|
version = "0.52.6"
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ anyhow = "1.0"
|
||||||
candle = { version = "0.9.1", package = "candle-core" }
|
candle = { version = "0.9.1", package = "candle-core" }
|
||||||
candle-nn = "0.9.1"
|
candle-nn = "0.9.1"
|
||||||
clap = { version = "4.4.12", features = ["derive"] }
|
clap = { version = "4.4.12", features = ["derive"] }
|
||||||
hf-hub = "0.3.2"
|
hf-hub = "0.4.3"
|
||||||
kaudio = "0.2.1"
|
kaudio = "0.2.1"
|
||||||
moshi = "0.6.1"
|
moshi = "0.6.1"
|
||||||
sentencepiece = "0.11.3"
|
sentencepiece = "0.11.3"
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,14 @@ struct Args {
|
||||||
/// Run the model on cpu.
|
/// Run the model on cpu.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Display word level timestamps.
|
||||||
|
#[arg(long)]
|
||||||
|
timestamps: bool,
|
||||||
|
|
||||||
|
/// Display the level of voice activity detection (VAD).
|
||||||
|
#[arg(long)]
|
||||||
|
vad: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn device(cpu: bool) -> Result<Device> {
|
fn device(cpu: bool) -> Result<Device> {
|
||||||
|
|
@ -32,6 +40,12 @@ fn device(cpu: bool) -> Result<Device> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize)]
|
||||||
|
struct SttConfig {
|
||||||
|
audio_silence_prefix_seconds: f64,
|
||||||
|
audio_delay_seconds: f64,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, serde::Deserialize)]
|
#[derive(Debug, serde::Deserialize)]
|
||||||
struct Config {
|
struct Config {
|
||||||
mimi_name: String,
|
mimi_name: String,
|
||||||
|
|
@ -45,10 +59,11 @@ struct Config {
|
||||||
num_heads: usize,
|
num_heads: usize,
|
||||||
num_layers: usize,
|
num_layers: usize,
|
||||||
causal: bool,
|
causal: bool,
|
||||||
|
stt_config: SttConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
fn model_config(&self) -> moshi::lm::Config {
|
fn model_config(&self, vad: bool) -> moshi::lm::Config {
|
||||||
let lm_cfg = moshi::transformer::Config {
|
let lm_cfg = moshi::transformer::Config {
|
||||||
d_model: self.dim,
|
d_model: self.dim,
|
||||||
num_heads: self.num_heads,
|
num_heads: self.num_heads,
|
||||||
|
|
@ -73,6 +88,14 @@ impl Config {
|
||||||
max_seq_len: 4096 * 4,
|
max_seq_len: 4096 * 4,
|
||||||
shared_cross_attn: false,
|
shared_cross_attn: false,
|
||||||
};
|
};
|
||||||
|
let extra_heads = if vad {
|
||||||
|
Some(moshi::lm::ExtraHeadsConfig {
|
||||||
|
num_heads: 4,
|
||||||
|
dim: 6,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
moshi::lm::Config {
|
moshi::lm::Config {
|
||||||
transformer: lm_cfg,
|
transformer: lm_cfg,
|
||||||
depformer: None,
|
depformer: None,
|
||||||
|
|
@ -81,7 +104,7 @@ impl Config {
|
||||||
text_out_vocab_size: self.text_card,
|
text_out_vocab_size: self.text_card,
|
||||||
audio_codebooks: self.n_q,
|
audio_codebooks: self.n_q,
|
||||||
conditioners: Default::default(),
|
conditioners: Default::default(),
|
||||||
extra_heads: None,
|
extra_heads,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -89,16 +112,19 @@ impl Config {
|
||||||
struct Model {
|
struct Model {
|
||||||
state: moshi::asr::State,
|
state: moshi::asr::State,
|
||||||
text_tokenizer: sentencepiece::SentencePieceProcessor,
|
text_tokenizer: sentencepiece::SentencePieceProcessor,
|
||||||
|
timestamps: bool,
|
||||||
|
vad: bool,
|
||||||
|
config: Config,
|
||||||
dev: Device,
|
dev: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
fn load_from_hf(hf_repo: &str, dev: &Device) -> Result<Self> {
|
fn load_from_hf(args: &Args, dev: &Device) -> Result<Self> {
|
||||||
let dtype = dev.bf16_default_to_f32();
|
let dtype = dev.bf16_default_to_f32();
|
||||||
|
|
||||||
// Retrieve the model files from the Hugging Face Hub
|
// Retrieve the model files from the Hugging Face Hub
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let repo = api.model(hf_repo.to_string());
|
let repo = api.model(args.hf_repo.to_string());
|
||||||
let config_file = repo.get("config.json")?;
|
let config_file = repo.get("config.json")?;
|
||||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?;
|
let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?;
|
||||||
let tokenizer_file = repo.get(&config.tokenizer_name)?;
|
let tokenizer_file = repo.get(&config.tokenizer_name)?;
|
||||||
|
|
@ -110,53 +136,86 @@ impl Model {
|
||||||
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? };
|
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? };
|
||||||
let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?;
|
let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?;
|
||||||
let lm = moshi::lm::LmModel::new(
|
let lm = moshi::lm::LmModel::new(
|
||||||
&config.model_config(),
|
&config.model_config(args.vad),
|
||||||
moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
|
moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
|
||||||
)?;
|
)?;
|
||||||
let state = moshi::asr::State::new(1, 0, 0., audio_tokenizer, lm)?;
|
let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
|
||||||
|
let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
|
||||||
Ok(Model {
|
Ok(Model {
|
||||||
state,
|
state,
|
||||||
|
config,
|
||||||
text_tokenizer,
|
text_tokenizer,
|
||||||
|
timestamps: args.timestamps,
|
||||||
|
vad: args.vad,
|
||||||
dev: dev.clone(),
|
dev: dev.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run(&mut self, pcm: &[f32]) -> Result<()> {
|
fn run(&mut self, mut pcm: Vec<f32>) -> Result<()> {
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
|
// Add the silence prefix to the audio.
|
||||||
|
if self.config.stt_config.audio_silence_prefix_seconds > 0.0 {
|
||||||
|
let silence_len =
|
||||||
|
(self.config.stt_config.audio_silence_prefix_seconds * 24000.0) as usize;
|
||||||
|
pcm.splice(0..0, vec![0.0; silence_len]);
|
||||||
|
}
|
||||||
|
// Add some silence at the end to ensure all the audio is processed.
|
||||||
|
let suffix = (self.config.stt_config.audio_delay_seconds * 24000.0) as usize;
|
||||||
|
pcm.resize(pcm.len() + suffix + 24000, 0.0);
|
||||||
|
|
||||||
|
let mut last_word = None;
|
||||||
|
let mut printed_eot = false;
|
||||||
for pcm in pcm.chunks(1920) {
|
for pcm in pcm.chunks(1920) {
|
||||||
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
|
let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
|
||||||
let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
|
let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
|
||||||
let mut prev_text_token = 0;
|
|
||||||
for asr_msg in asr_msgs.iter() {
|
for asr_msg in asr_msgs.iter() {
|
||||||
match asr_msg {
|
match asr_msg {
|
||||||
moshi::asr::AsrMsg::Step { .. } | moshi::asr::AsrMsg::EndWord { .. } => {}
|
moshi::asr::AsrMsg::Step { prs, .. } => {
|
||||||
moshi::asr::AsrMsg::Word { tokens, .. } => {
|
// prs is the probability of having no voice activity for different time
|
||||||
for &text_token in tokens.iter() {
|
// horizons.
|
||||||
let s = {
|
// In kyutai/stt-1b-en_fr-candle, these horizons are 0.5s, 1s, 2s, and 3s.
|
||||||
let prev_ids =
|
if self.vad && prs[2][0] > 0.5 && !printed_eot {
|
||||||
self.text_tokenizer.decode_piece_ids(&[prev_text_token]);
|
printed_eot = true;
|
||||||
let ids = self
|
if !self.timestamps {
|
||||||
.text_tokenizer
|
print!(" <endofturn pr={}>", prs[2][0]);
|
||||||
.decode_piece_ids(&[prev_text_token, text_token]);
|
|
||||||
prev_text_token = text_token;
|
|
||||||
prev_ids.and_then(|prev_ids| {
|
|
||||||
ids.map(|ids| {
|
|
||||||
if ids.len() > prev_ids.len() {
|
|
||||||
ids[prev_ids.len()..].to_string()
|
|
||||||
} else {
|
} else {
|
||||||
String::new()
|
println!("<endofturn pr={}>", prs[2][0]);
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
})?
|
}
|
||||||
};
|
moshi::asr::AsrMsg::EndWord { stop_time, .. } => {
|
||||||
print!("{s}");
|
printed_eot = false;
|
||||||
|
if self.timestamps {
|
||||||
|
if let Some((word, start_time)) = last_word.take() {
|
||||||
|
println!("[{start_time:5.2}-{stop_time:5.2}] {word}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
moshi::asr::AsrMsg::Word {
|
||||||
|
tokens, start_time, ..
|
||||||
|
} => {
|
||||||
|
printed_eot = false;
|
||||||
|
let word = self
|
||||||
|
.text_tokenizer
|
||||||
|
.decode_piece_ids(tokens)
|
||||||
|
.unwrap_or_else(|_| String::new());
|
||||||
|
if !self.timestamps {
|
||||||
|
print!(" {word}");
|
||||||
std::io::stdout().flush()?
|
std::io::stdout().flush()?
|
||||||
|
} else {
|
||||||
|
if let Some((word, prev_start_time)) = last_word.take() {
|
||||||
|
println!("[{prev_start_time:5.2}-{start_time:5.2}] {word}");
|
||||||
|
}
|
||||||
|
last_word = Some((word, *start_time));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if let Some((word, start_time)) = last_word.take() {
|
||||||
|
println!("[{start_time:5.2}- ] {word}");
|
||||||
|
}
|
||||||
println!();
|
println!();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -168,17 +227,15 @@ fn main() -> Result<()> {
|
||||||
println!("Using device: {:?}", device);
|
println!("Using device: {:?}", device);
|
||||||
|
|
||||||
println!("Loading audio file from: {}", args.in_file);
|
println!("Loading audio file from: {}", args.in_file);
|
||||||
let (pcm, sample_rate) = kaudio::pcm_decode(args.in_file)?;
|
let (pcm, sample_rate) = kaudio::pcm_decode(&args.in_file)?;
|
||||||
let mut pcm = if sample_rate != 24_000 {
|
let pcm = if sample_rate != 24_000 {
|
||||||
kaudio::resample(&pcm, sample_rate as usize, 24_000)?
|
kaudio::resample(&pcm, sample_rate as usize, 24_000)?
|
||||||
} else {
|
} else {
|
||||||
pcm
|
pcm
|
||||||
};
|
};
|
||||||
// Add some silence at the end to ensure all the audio is processed.
|
|
||||||
pcm.resize(pcm.len() + 1920 * 32, 0.0);
|
|
||||||
println!("Loading model from repository: {}", args.hf_repo);
|
println!("Loading model from repository: {}", args.hf_repo);
|
||||||
let mut model = Model::load_from_hf(&args.hf_repo, &device)?;
|
let mut model = Model::load_from_hf(&args, &device)?;
|
||||||
println!("Running inference");
|
println!("Running inference");
|
||||||
model.run(&pcm)?;
|
model.run(pcm)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
238
stt_pytorch.ipynb
Normal file
238
stt_pytorch.ipynb
Normal file
|
|
@ -0,0 +1,238 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "gJEMjPgeI-rw",
|
||||||
|
"outputId": "7491c067-b1be-4505-b3f5-19ba4c00a593"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!pip install moshi"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "CA4K5iDFJcqJ",
|
||||||
|
"outputId": "b609843a-a193-4729-b099-5f8780532333"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "VA3Haix3IZ8Q"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from dataclasses import dataclass\n",
|
||||||
|
"import time\n",
|
||||||
|
"import sentencepiece\n",
|
||||||
|
"import sphn\n",
|
||||||
|
"import textwrap\n",
|
||||||
|
"import torch\n",
|
||||||
|
"\n",
|
||||||
|
"from moshi.models import loaders, MimiModel, LMModel, LMGen"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "9AK5zBMTI9bw"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"@dataclass\n",
|
||||||
|
"class InferenceState:\n",
|
||||||
|
" mimi: MimiModel\n",
|
||||||
|
" text_tokenizer: sentencepiece.SentencePieceProcessor\n",
|
||||||
|
" lm_gen: LMGen\n",
|
||||||
|
"\n",
|
||||||
|
" def __init__(\n",
|
||||||
|
" self,\n",
|
||||||
|
" mimi: MimiModel,\n",
|
||||||
|
" text_tokenizer: sentencepiece.SentencePieceProcessor,\n",
|
||||||
|
" lm: LMModel,\n",
|
||||||
|
" batch_size: int,\n",
|
||||||
|
" device: str | torch.device,\n",
|
||||||
|
" ):\n",
|
||||||
|
" self.mimi = mimi\n",
|
||||||
|
" self.text_tokenizer = text_tokenizer\n",
|
||||||
|
" self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)\n",
|
||||||
|
" self.device = device\n",
|
||||||
|
" self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)\n",
|
||||||
|
" self.batch_size = batch_size\n",
|
||||||
|
" self.mimi.streaming_forever(batch_size)\n",
|
||||||
|
" self.lm_gen.streaming_forever(batch_size)\n",
|
||||||
|
"\n",
|
||||||
|
" def run(self, in_pcms: torch.Tensor):\n",
|
||||||
|
" device = self.lm_gen.lm_model.device\n",
|
||||||
|
" ntokens = 0\n",
|
||||||
|
" first_frame = True\n",
|
||||||
|
" chunks = [\n",
|
||||||
|
" c\n",
|
||||||
|
" for c in in_pcms.split(self.frame_size, dim=2)\n",
|
||||||
|
" if c.shape[-1] == self.frame_size\n",
|
||||||
|
" ]\n",
|
||||||
|
" start_time = time.time()\n",
|
||||||
|
" all_text = []\n",
|
||||||
|
" for chunk in chunks:\n",
|
||||||
|
" codes = self.mimi.encode(chunk)\n",
|
||||||
|
" if first_frame:\n",
|
||||||
|
" # Ensure that the first slice of codes is properly seen by the transformer\n",
|
||||||
|
" # as otherwise the first slice is replaced by the initial tokens.\n",
|
||||||
|
" tokens = self.lm_gen.step(codes)\n",
|
||||||
|
" first_frame = False\n",
|
||||||
|
" tokens = self.lm_gen.step(codes)\n",
|
||||||
|
" if tokens is None:\n",
|
||||||
|
" continue\n",
|
||||||
|
" assert tokens.shape[1] == 1\n",
|
||||||
|
" one_text = tokens[0, 0].cpu()\n",
|
||||||
|
" if one_text.item() not in [0, 3]:\n",
|
||||||
|
" text = self.text_tokenizer.id_to_piece(one_text.item())\n",
|
||||||
|
" text = text.replace(\"▁\", \" \")\n",
|
||||||
|
" all_text.append(text)\n",
|
||||||
|
" ntokens += 1\n",
|
||||||
|
" dt = time.time() - start_time\n",
|
||||||
|
" print(\n",
|
||||||
|
" f\"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step\"\n",
|
||||||
|
" )\n",
|
||||||
|
" return \"\".join(all_text)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/",
|
||||||
|
"height": 353,
|
||||||
|
"referenced_widgets": [
|
||||||
|
"0a5f6f887e2b4cd1990a0e9ec0153ed9",
|
||||||
|
"f7893826fcba4bdc87539589d669249b",
|
||||||
|
"8805afb12c484781be85082ff02dad13",
|
||||||
|
"97679c0d9ab44bed9a3456f2fcb541fd",
|
||||||
|
"d73c0321bed54a52b5e1da0a7788e32a",
|
||||||
|
"d67be13a920d4fc89e5570b5b29fc1d2",
|
||||||
|
"6b377c2d7bf945fb89e46c39d246a332",
|
||||||
|
"b82ff365c78e41ad8094b46daf79449d",
|
||||||
|
"477aa7fa82dc42d5bce6f1743c45d626",
|
||||||
|
"cbd288510c474430beb66f346f382c45",
|
||||||
|
"aafc347cdf28428ea6a7abe5b46b726f",
|
||||||
|
"fca09acd5d0d45468c8b04bfb2de7646",
|
||||||
|
"79e35214b51b4a9e9b3f7144b0b34f7b",
|
||||||
|
"89e9a37f69904bd48b954d627bff6687",
|
||||||
|
"57028789c78248a7b0ad4f031c9545c9",
|
||||||
|
"1150fcb427994c2984d4d0f4e4745fe5",
|
||||||
|
"e24b1fc52f294f849019c9b3befb613f",
|
||||||
|
"8724878682cf4c3ca992667c45009398",
|
||||||
|
"36a22c977d5242008871310133b7d2af",
|
||||||
|
"5b3683cad5cb4877b43fadd003edf97f",
|
||||||
|
"703f98272e4d469d8f27f5a465715dd8",
|
||||||
|
"9dbe02ef5fac41cfaee3d02946e65c88",
|
||||||
|
"37faa87ad03a4271992c21ce6a629e18",
|
||||||
|
"570c547e48cd421b814b2c5e028e4c0b",
|
||||||
|
"b173768580fc4c0a8e3abf272e4c363a",
|
||||||
|
"e57d1620f0a9427b85d8b4885ef4e8e3",
|
||||||
|
"5dd4474df70743498b616608182714dd",
|
||||||
|
"cc907676a65f4ad1bf68a77b4a00e89b",
|
||||||
|
"a34abc3b118e4305951a466919c28ff6",
|
||||||
|
"a77ccfcdb90146c7a63b4b2d232bc494",
|
||||||
|
"f7313e6e3a27475993cab3961d6ae363",
|
||||||
|
"39b47fad9c554839868fe9e4bbf7def2",
|
||||||
|
"14e9511ea0bd44c49f0cf3abf1a6d40e",
|
||||||
|
"a4ea8e0c4cac4d5e88b7e3f527e4fe90",
|
||||||
|
"571afc0f4b2840c9830d6b5a307ed1f9",
|
||||||
|
"6ec593cab5b64f0ea638bb175b9daa5c",
|
||||||
|
"77a52aed00ae408bb24524880e19ec8a",
|
||||||
|
"0b2de4b29b4b44fe9d96361a40c793d0",
|
||||||
|
"3c5b5fb1a5ac468a89c1058bd90cfb58",
|
||||||
|
"e53e0a2a240e43cfa562c89b3d703dea",
|
||||||
|
"35966343cf9249ef8bc028a0d5c5f97d",
|
||||||
|
"e36a37e0d41c47ccb8bc6d56c19fb17c",
|
||||||
|
"279ccf7de43847a1a6579c9182a46cc8",
|
||||||
|
"41b5d6ab0b7d43c790a55f125c0e7494"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"id": "UsQJdAgkLp9n",
|
||||||
|
"outputId": "9b7131c3-69c5-4323-8312-2ce7621d8869"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"device = \"cuda\"\n",
|
||||||
|
"# Use the en+fr low latency model, an alternative is kyutai/stt-2.6b-en\n",
|
||||||
|
"checkpoint_info = loaders.CheckpointInfo.from_hf_repo(\"kyutai/stt-1b-en_fr\")\n",
|
||||||
|
"mimi = checkpoint_info.get_mimi(device=device)\n",
|
||||||
|
"text_tokenizer = checkpoint_info.get_text_tokenizer()\n",
|
||||||
|
"lm = checkpoint_info.get_moshi(device=device)\n",
|
||||||
|
"in_pcms, _ = sphn.read(\"sample_fr_hibiki_crepes.mp3\", sample_rate=mimi.sample_rate)\n",
|
||||||
|
"in_pcms = torch.from_numpy(in_pcms).to(device=device)\n",
|
||||||
|
"\n",
|
||||||
|
"stt_config = checkpoint_info.stt_config\n",
|
||||||
|
"pad_left = int(stt_config.get(\"audio_silence_prefix_seconds\", 0.0) * 24000)\n",
|
||||||
|
"pad_right = int((stt_config.get(\"audio_delay_seconds\", 0.0) + 1.0) * 24000)\n",
|
||||||
|
"in_pcms = torch.nn.functional.pad(in_pcms, (pad_left, pad_right), mode=\"constant\")\n",
|
||||||
|
"in_pcms = in_pcms[None, 0:1].expand(1, -1, -1)\n",
|
||||||
|
"\n",
|
||||||
|
"state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)\n",
|
||||||
|
"text = state.run(in_pcms)\n",
|
||||||
|
"print(textwrap.fill(text, width=100))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/",
|
||||||
|
"height": 75
|
||||||
|
},
|
||||||
|
"id": "CIAXs9oaPrtj",
|
||||||
|
"outputId": "94cc208c-2454-4dd4-a64e-d79025144af5"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from IPython.display import Audio\n",
|
||||||
|
"\n",
|
||||||
|
"Audio(\"sample_fr_hibiki_crepes.mp3\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "qkUZ6CBKOdTa"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"accelerator": "GPU",
|
||||||
|
"colab": {
|
||||||
|
"gpuType": "L4",
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
||||||
140
tts_pytorch.ipynb
Normal file
140
tts_pytorch.ipynb
Normal file
|
|
@ -0,0 +1,140 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Fast install, might break in the future.\n",
|
||||||
|
"!pip install 'sphn<0.2'\n",
|
||||||
|
"!pip install --no-deps \"moshi==0.2.11\"\n",
|
||||||
|
"# Slow install (will download torch and cuda), but future proof.\n",
|
||||||
|
"# !pip install \"moshi==0.2.11\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import argparse\n",
|
||||||
|
"import sys\n",
|
||||||
|
"\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from moshi.models.loaders import CheckpointInfo\n",
|
||||||
|
"from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel\n",
|
||||||
|
"\n",
|
||||||
|
"from IPython.display import display, Audio"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Configuration\n",
|
||||||
|
"text = \"Hey there! How are you? I had the craziest day today.\"\n",
|
||||||
|
"voice = \"expresso/ex03-ex01_happy_001_channel1_334s.wav\"\n",
|
||||||
|
"print(f\"See https://huggingface.co/{DEFAULT_DSM_TTS_VOICE_REPO} for available voices.\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Set everything up\n",
|
||||||
|
"checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n",
|
||||||
|
"tts_model = TTSModel.from_checkpoint_info(\n",
|
||||||
|
" checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\")\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n",
|
||||||
|
"entries = tts_model.prepare_script([text], padding_between=1)\n",
|
||||||
|
"voice_path = tts_model.get_voice_path(voice)\n",
|
||||||
|
"# CFG coef goes here because the model was trained with CFG distillation,\n",
|
||||||
|
"# so it's not _actually_ doing CFG at inference time.\n",
|
||||||
|
"# Also, if you are generating a dialog, you should have two voices in the list.\n",
|
||||||
|
"condition_attributes = tts_model.make_condition_attributes(\n",
|
||||||
|
" [voice_path], cfg_coef=2.0\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "4",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print(\"Generating audio...\")\n",
|
||||||
|
"\n",
|
||||||
|
"pcms = []\n",
|
||||||
|
"def _on_frame(frame):\n",
|
||||||
|
" print(\"Step\", len(pcms), end=\"\\r\")\n",
|
||||||
|
" if (frame != -1).all():\n",
|
||||||
|
" pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n",
|
||||||
|
" pcms.append(np.clip(pcm[0, 0], -1, 1))\n",
|
||||||
|
"\n",
|
||||||
|
"# You could also generate multiple audios at once by extending the following lists.\n",
|
||||||
|
"all_entries = [entries]\n",
|
||||||
|
"all_condition_attributes = [condition_attributes]\n",
|
||||||
|
"with tts_model.mimi.streaming(len(all_entries)):\n",
|
||||||
|
" result = tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Done generating.\")\n",
|
||||||
|
"audio = np.concatenate(pcms, axis=-1)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"display(\n",
|
||||||
|
" Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.13.2"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user