Compare commits

...

3 Commits

Author          SHA1        Message                                                    Date
Václav Volhejn  8c18018b6f  Link to colab notebook via github                          2025-06-25 11:27:33 +02:00
Václav Volhejn  bb0bdbf697  Fix references to scripts, add implementations overview    2025-06-25 11:15:46 +02:00
Václav Volhejn  7b818c2636  Rename examples and add pre-commit                         2025-06-25 10:50:14 +02:00
8 changed files with 374 additions and 66 deletions

.pre-commit-config.yaml (new file)

@@ -0,0 +1,22 @@
repos:
# Get rid of Jupyter Notebook output because we don't want to keep it in Git
- repo: https://github.com/kynan/nbstripout
rev: 0.8.1
hooks:
- id: nbstripout
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-added-large-files
args: ["--maxkb=2048"]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.7
hooks:
# Run the linter.
- id: ruff
types_or: [python, pyi] # Don't run on `jupyter` files
args: [--fix]
# Run the formatter.
- id: ruff-format
types_or: [python, pyi] # Don't run on `jupyter` files

README.md

@@ -1,7 +1,7 @@
<a href="https://huggingface.co/collections/kyutai/speech-to-text-685403682cf8a23ab9466886" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiSTT-blue" style="display: inline-block; vertical-align: middle;"/>
</a>
<a target="_blank" href="https://colab.research.google.com/drive/1mc0Q-FoHxU2pEvId8rTdS4q1r1zorJhS?usp=sharing">
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/transcribe_via_pytorch.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
@@ -33,6 +33,21 @@ These speech-to-text models have several advantages:
can be used to detect when the user is speaking. This is especially useful
for building voice agents.
### Implementations overview
We provide different implementations of Kyutai STT for different use cases.
Here is how to choose which one to use:
- **PyTorch: for research and tinkering.**
If you want to call the model from Python for research or experimentation, use our PyTorch implementation.
- **Rust: for production.**
If you want to serve Kyutai STT in a production setting, use our Rust server.
Our robust Rust server provides streaming access to the model over websockets.
We use this server to run [Unmute](https://unmute.sh/); on an L40S GPU, we can serve 64 simultaneous connections at a real-time factor of 3x.
- **MLX: for on-device inference on iPhone and Mac.**
MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
You can retrieve the sample files used in the following snippets via:
```bash
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
@@ -43,10 +58,14 @@ wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hib
<a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
</a>
<a target="_blank" href="https://colab.research.google.com/drive/1mc0Q-FoHxU2pEvId8rTdS4q1r1zorJhS?usp=sharing">
<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/transcribe_via_pytorch.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
For an example of how to stream PyTorch tensors directly into the model,
[see our Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/transcribe_via_pytorch.ipynb).
If you just want to run the model on a file, you can use `moshi.run_inference`.
This requires the [moshi package](https://pypi.org/project/moshi/)
with version 0.2.6 or later, which can be installed via pip.
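As a minimal sketch of the pip route (mirroring the `uvx` one-liner shown below, and assuming the `bria.mp3` sample downloaded earlier):
```bash
# Install the moshi package (0.2.6 or later is required).
pip install "moshi>=0.2.6"
# Transcribe a local audio file with the 2.6B English model.
python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
```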
@@ -58,25 +77,25 @@ If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the install
```bash
uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3
```
This installs the moshi package in a temporary environment and runs speech-to-text on the given file.
Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs:
```bash
uv run \
scripts/streaming_stt_timestamps.py \
scripts/transcribe_from_file_via_pytorch.py \
--hf-repo kyutai/stt-2.6b-en \
--file bria.mp3
```
The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
```bash
uv run scripts/streaming_stt.py \
uv run scripts/evaluate_on_dataset.py \
--dataset meanwhile \
--hf-repo kyutai/stt-2.6b-en
```
### Rust server
<a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
</a>
@@ -104,15 +123,19 @@ and for `kyutai/stt-2.6b-en`, use `configs/config-stt-en-hf.toml`,
moshi-server worker --config configs/config-stt-en_fr-hf.toml
```
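For the English-only model mentioned above, the analogous invocation would presumably be:
```bash
moshi-server worker --config configs/config-stt-en-hf.toml
```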
Once the server has started you can run a streaming inference with the following
script.
Once the server has started, you can transcribe audio from your microphone with the following script.
```bash
uv run scripts/asr-streaming-query.py bria.mp3
uv run scripts/transcribe_from_mic_via_rust_server.py
```
We also provide a script for transcribing from an audio file.
```bash
uv run scripts/transcribe_from_file_via_rust_server.py bria.mp3
```
The script limits the decoding speed to simulate real-time processing of the audio.
Faster processing can be triggered by setting
the real-time factor, e.g. `--rtf 500` will process
the real-time factor, e.g. `--rtf 1000` will process
the data as fast as possible.
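For example, assuming the renamed script keeps the same positional file argument and `--rtf` flag:
```bash
uv run scripts/transcribe_from_file_via_rust_server.py bria.mp3 --rtf 1000
```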
### Rust standalone
@@ -166,3 +189,14 @@ Note that parts of this code are based on [AudioCraft](https://github.com/faceboo
the MIT license.
The weights for the speech-to-text models are released under the CC-BY 4.0 license.
## Developing
Install the [pre-commit hooks](https://pre-commit.com/) by running:
```bash
pip install pre-commit
pre-commit install
```
If you're using `uv`, you can replace the two commands with `uvx pre-commit install`.
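As a usage note (standard pre-commit behavior, not specific to this repository), the installed hooks can also be run against all files manually:
```bash
pre-commit run --all-files
```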


@@ -41,18 +41,17 @@ uv run scripts/streaming_stt.py \
# Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34
# Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15
import dataclasses
import julius
import jiwer
from datasets import load_dataset, Dataset
from whisper.normalizers import EnglishTextNormalizer
import argparse
import torch
import moshi.models
import tqdm
import dataclasses
import time
import jiwer
import julius
import moshi.models
import torch
import tqdm
from datasets import Dataset, load_dataset
from whisper.normalizers import EnglishTextNormalizer
_NORMALIZER = EnglishTextNormalizer()
@@ -120,9 +119,9 @@ class AsrMetrics:
self.num_sequences += 1
def compute(self) -> dict:
assert (
self.num_sequences > 0
), "Unable to compute with total number of comparisons <= 0" # type: ignore
assert self.num_sequences > 0, (
"Unable to compute with total number of comparisons <= 0"
) # type: ignore
return {
"cer": (self.cer_sum / self.num_sequences),
"wer": (self.wer_sum / self.num_sequences),


@@ -19,15 +19,15 @@ uv run scripts/streaming_stt_timestamps.py \
```
"""
import itertools
import dataclasses
import julius
import sphn
import argparse
import dataclasses
import itertools
import math
import torch
import julius
import moshi.models
import sphn
import torch
import tqdm


@@ -10,17 +10,16 @@
import argparse
import asyncio
import json
import msgpack
import sphn
import struct
import time
import numpy as np
import msgpack
import sphn
import websockets
# Desired audio properties
TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
HEADERS = {"kyutai-api-key": "open_token"}
all_text = []
transcript = []
finished = False
@@ -44,11 +43,13 @@ async def receive_messages(websocket):
print("received:", data)
if data["type"] == "Word":
all_text.append(data["text"])
transcript.append({
transcript.append(
{
"speaker": "SPEAKER_00",
"text": data["text"],
"timestamp": [data["start_time"], data["start_time"]],
})
}
)
if data["type"] == "EndWord":
if len(transcript) > 0:
transcript[-1]["timestamp"][1] = data["stop_time"]
@@ -64,15 +65,19 @@ async def send_messages(websocket, rtf: float):
global finished
audio_data = load_and_process_audio(args.in_file)
try:
# Start with a second of silence
chunk = { "type": "Audio", "pcm": [0.0] * 24000 }
# Start with a second of silence.
# This is needed for the 2.6B model for technical reasons.
chunk = {"type": "Audio", "pcm": [0.0] * 24000}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
chunk_size = 1920 # Send data in chunks
start_time = time.time()
for i in range(0, len(audio_data), chunk_size):
chunk = { "type": "Audio", "pcm": [float(x) for x in audio_data[i : i + chunk_size]] }
chunk = {
"type": "Audio",
"pcm": [float(x) for x in audio_data[i : i + chunk_size]],
}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
expected_send_time = start_time + (i + 1) / 24000 / rtf
@@ -81,13 +86,15 @@ async def send_messages(websocket, rtf: float):
await asyncio.sleep(expected_send_time - current_time)
else:
await asyncio.sleep(0.001)
chunk = { "type": "Audio", "pcm": [0.0] * 1920 * 5 }
chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
msg = msgpack.packb({"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True)
msg = msgpack.packb(
{"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True
)
await websocket.send(msg)
for _ in range(35):
chunk = { "type": "Audio", "pcm": [0.0] * 1920 }
chunk = {"type": "Audio", "pcm": [0.0] * 1920}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
while True:
@@ -100,11 +107,10 @@ async def send_messages(websocket, rtf: float):
print("Connection closed while sending messages.")
async def stream_audio(url: str, rtf: float, api_key: str):
async def stream_audio(url: str, rtf: float):
"""Stream audio data to a WebSocket server."""
headers = {"kyutai-api-key": api_key}
async with websockets.connect(url, additional_headers=headers) as websocket:
async with websockets.connect(url, additional_headers=HEADERS) as websocket:
send_task = asyncio.create_task(send_messages(websocket, rtf))
receive_task = asyncio.create_task(receive_messages(websocket))
await asyncio.gather(send_task, receive_task)
@@ -115,7 +121,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("in_file")
parser.add_argument("--transcript")
parser.add_argument("--api-key", default="open_token")
parser.add_argument(
"--url",
help="The url of the server to which to send the audio",
@@ -125,7 +130,7 @@ if __name__ == "__main__":
args = parser.parse_args()
url = f"{args.url}/api/asr-streaming"
asyncio.run(stream_audio(url, args.rtf, args.api_key))
asyncio.run(stream_audio(url, args.rtf))
print(" ".join(all_text))
if args.transcript is not None:
with open(args.transcript, "w") as fobj:


@@ -11,19 +11,16 @@
# ///
import argparse
from dataclasses import dataclass
import json
import numpy as np
import queue
import sounddevice as sd
from huggingface_hub import hf_hub_download
import mlx.core as mx
import mlx.nn as nn
from moshi_mlx import models, utils
import rustymimi
import sentencepiece
import sounddevice as sd
from huggingface_hub import hf_hub_download
from moshi_mlx import models, utils
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -69,6 +66,7 @@ if __name__ == "__main__":
)
block_queue = queue.Queue()
def audio_callback(indata, _frames, _time, _status):
block_queue.put(indata.copy())
@@ -84,7 +82,9 @@ if __name__ == "__main__":
block = block_queue.get()
block = block[None, :, 0]
other_audio_tokens = audio_tokenizer.encode_step(block[None, 0:1])
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[:, :, :other_codebooks]
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
:, :, :other_codebooks
]
text_token = gen.step(other_audio_tokens[0])
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
@@ -93,4 +93,3 @@ if __name__ == "__main__":
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)


@@ -9,9 +9,9 @@
# ///
import argparse
import asyncio
import msgpack
import signal
import msgpack
import numpy as np
import sounddevice as sd
import websockets
@@ -21,6 +21,7 @@ TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
audio_queue = asyncio.Queue()
async def receive_messages(websocket):
"""Receive and process messages from the WebSocket server."""
try:
@@ -47,22 +48,26 @@ async def send_messages(websocket):
except websockets.ConnectionClosed:
print("Connection closed while sending messages.")
async def stream_audio(url: str, api_key: str):
"""Stream audio data to a WebSocket server."""
print("Starting microphone recording...")
print("Press Ctrl+C to stop recording")
loop = asyncio.get_event_loop()
def audio_callback(indata, frames, time, status):
loop.call_soon_threadsafe(audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy())
loop.call_soon_threadsafe(
audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy()
)
# Start audio stream
with sd.InputStream(
samplerate=TARGET_SAMPLE_RATE,
channels=TARGET_CHANNELS,
dtype='float32',
dtype="float32",
callback=audio_callback,
blocksize=1920 # 80ms blocks
blocksize=1920, # 80ms blocks
):
headers = {"kyutai-api-key": api_key}
async with websockets.connect(url, additional_headers=headers) as websocket:
@@ -79,8 +84,12 @@ if __name__ == "__main__":
default="ws://127.0.0.1:8080",
)
parser.add_argument("--api-key", default="open_token")
parser.add_argument("--list-devices", action="store_true", help="List available audio devices")
parser.add_argument("--device", type=int, help="Input device ID (use --list-devices to see options)")
parser.add_argument(
"--list-devices", action="store_true", help="List available audio devices"
)
parser.add_argument(
"--device", type=int, help="Input device ID (use --list-devices to see options)"
)
args = parser.parse_args()

transcribe_via_pytorch.ipynb (new file)

@@ -0,0 +1,240 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gJEMjPgeI-rw",
"outputId": "7491c067-b1be-4505-b3f5-19ba4c00a593"
},
"outputs": [],
"source": [
"!pip install moshi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "CA4K5iDFJcqJ",
"outputId": "b609843a-a193-4729-b099-5f8780532333"
},
"outputs": [],
"source": [
"!wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VA3Haix3IZ8Q"
},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"import time\n",
"import sentencepiece\n",
"import sphn\n",
"import textwrap\n",
"import torch\n",
"\n",
"from moshi.models import loaders, MimiModel, LMModel, LMGen"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9AK5zBMTI9bw"
},
"outputs": [],
"source": [
"@dataclass\n",
"class InferenceState:\n",
" mimi: MimiModel\n",
" text_tokenizer: sentencepiece.SentencePieceProcessor\n",
" lm_gen: LMGen\n",
"\n",
" def __init__(\n",
" self,\n",
" mimi: MimiModel,\n",
" text_tokenizer: sentencepiece.SentencePieceProcessor,\n",
" lm: LMModel,\n",
" batch_size: int,\n",
" device: str | torch.device,\n",
" ):\n",
" self.mimi = mimi\n",
" self.text_tokenizer = text_tokenizer\n",
" self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)\n",
" self.device = device\n",
" self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)\n",
" self.batch_size = batch_size\n",
" self.mimi.streaming_forever(batch_size)\n",
" self.lm_gen.streaming_forever(batch_size)\n",
"\n",
" def run(self, in_pcms: torch.Tensor):\n",
" device = self.lm_gen.lm_model.device\n",
" ntokens = 0\n",
" first_frame = True\n",
" chunks = [\n",
" c\n",
" for c in in_pcms.split(self.frame_size, dim=2)\n",
" if c.shape[-1] == self.frame_size\n",
" ]\n",
" start_time = time.time()\n",
" all_text = []\n",
" for chunk in chunks:\n",
" codes = self.mimi.encode(chunk)\n",
" if first_frame:\n",
" # Ensure that the first slice of codes is properly seen by the transformer\n",
" # as otherwise the first slice is replaced by the initial tokens.\n",
" tokens = self.lm_gen.step(codes)\n",
" first_frame = False\n",
" tokens = self.lm_gen.step(codes)\n",
" if tokens is None:\n",
" continue\n",
" assert tokens.shape[1] == 1\n",
" one_text = tokens[0, 0].cpu()\n",
" if one_text.item() not in [0, 3]:\n",
" text = self.text_tokenizer.id_to_piece(one_text.item())\n",
" text = text.replace(\"▁\", \" \")\n",
" all_text.append(text)\n",
" ntokens += 1\n",
" dt = time.time() - start_time\n",
" print(\n",
" f\"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step\"\n",
" )\n",
" return \"\".join(all_text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 353,
"referenced_widgets": [
"0a5f6f887e2b4cd1990a0e9ec0153ed9",
"f7893826fcba4bdc87539589d669249b",
"8805afb12c484781be85082ff02dad13",
"97679c0d9ab44bed9a3456f2fcb541fd",
"d73c0321bed54a52b5e1da0a7788e32a",
"d67be13a920d4fc89e5570b5b29fc1d2",
"6b377c2d7bf945fb89e46c39d246a332",
"b82ff365c78e41ad8094b46daf79449d",
"477aa7fa82dc42d5bce6f1743c45d626",
"cbd288510c474430beb66f346f382c45",
"aafc347cdf28428ea6a7abe5b46b726f",
"fca09acd5d0d45468c8b04bfb2de7646",
"79e35214b51b4a9e9b3f7144b0b34f7b",
"89e9a37f69904bd48b954d627bff6687",
"57028789c78248a7b0ad4f031c9545c9",
"1150fcb427994c2984d4d0f4e4745fe5",
"e24b1fc52f294f849019c9b3befb613f",
"8724878682cf4c3ca992667c45009398",
"36a22c977d5242008871310133b7d2af",
"5b3683cad5cb4877b43fadd003edf97f",
"703f98272e4d469d8f27f5a465715dd8",
"9dbe02ef5fac41cfaee3d02946e65c88",
"37faa87ad03a4271992c21ce6a629e18",
"570c547e48cd421b814b2c5e028e4c0b",
"b173768580fc4c0a8e3abf272e4c363a",
"e57d1620f0a9427b85d8b4885ef4e8e3",
"5dd4474df70743498b616608182714dd",
"cc907676a65f4ad1bf68a77b4a00e89b",
"a34abc3b118e4305951a466919c28ff6",
"a77ccfcdb90146c7a63b4b2d232bc494",
"f7313e6e3a27475993cab3961d6ae363",
"39b47fad9c554839868fe9e4bbf7def2",
"14e9511ea0bd44c49f0cf3abf1a6d40e",
"a4ea8e0c4cac4d5e88b7e3f527e4fe90",
"571afc0f4b2840c9830d6b5a307ed1f9",
"6ec593cab5b64f0ea638bb175b9daa5c",
"77a52aed00ae408bb24524880e19ec8a",
"0b2de4b29b4b44fe9d96361a40c793d0",
"3c5b5fb1a5ac468a89c1058bd90cfb58",
"e53e0a2a240e43cfa562c89b3d703dea",
"35966343cf9249ef8bc028a0d5c5f97d",
"e36a37e0d41c47ccb8bc6d56c19fb17c",
"279ccf7de43847a1a6579c9182a46cc8",
"41b5d6ab0b7d43c790a55f125c0e7494"
]
},
"id": "UsQJdAgkLp9n",
"outputId": "9b7131c3-69c5-4323-8312-2ce7621d8869"
},
"outputs": [],
"source": [
"device = \"cuda\"\n",
"# Use the en+fr low latency model, an alternative is kyutai/stt-2.6b-en\n",
"checkpoint_info = loaders.CheckpointInfo.from_hf_repo(\"kyutai/stt-1b-en_fr\")\n",
"mimi = checkpoint_info.get_mimi(device=device)\n",
"text_tokenizer = checkpoint_info.get_text_tokenizer()\n",
"lm = checkpoint_info.get_moshi(device=device)\n",
"in_pcms, _ = sphn.read(\"sample_fr_hibiki_crepes.mp3\", sample_rate=mimi.sample_rate)\n",
"in_pcms = torch.from_numpy(in_pcms).to(device=device)\n",
"\n",
"stt_config = checkpoint_info.stt_config\n",
"pad_left = int(stt_config.get(\"audio_silence_prefix_seconds\", 0.0) * 24000)\n",
"pad_right = int((stt_config.get(\"audio_delay_seconds\", 0.0) + 1.0) * 24000)\n",
"in_pcms = torch.nn.functional.pad(in_pcms, (pad_left, pad_right), mode=\"constant\")\n",
"in_pcms = in_pcms[None, 0:1].expand(1, -1, -1)\n",
"\n",
"state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)\n",
"text = state.run(in_pcms)\n",
"print(textwrap.fill(text, width=100))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 75
},
"id": "CIAXs9oaPrtj",
"outputId": "94cc208c-2454-4dd4-a64e-d79025144af5"
},
"outputs": [],
"source": [
"from IPython.display import Audio\n",
"\n",
"Audio(\"sample_fr_hibiki_crepes.mp3\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qkUZ6CBKOdTa"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "L4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}