Rename examples and add pre-commit

Václav Volhejn 2025-06-25 10:50:14 +02:00
parent 8bd3f59631
commit 7b818c2636
8 changed files with 347 additions and 58 deletions

.pre-commit-config.yaml (new file, 22 lines)

@@ -0,0 +1,22 @@
repos:
# Get rid of Jupyter Notebook output because we don't want to keep it in Git
- repo: https://github.com/kynan/nbstripout
rev: 0.8.1
hooks:
- id: nbstripout
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-added-large-files
args: ["--maxkb=2048"]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.7
hooks:
# Run the linter.
- id: ruff
types_or: [python, pyi] # Don't run on `jupyter` files
args: [--fix]
# Run the formatter.
- id: ruff-format
types_or: [python, pyi] # Don't run on `jupyter` files
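Once these hooks are installed (see the Developing section of the README below), `nbstripout` clears notebook outputs automatically at commit time. As a rough sketch, the same stripping can also be run by hand on a single notebook via `uvx` (the file name here is only illustrative):

```bash
# Strip outputs and execution counts from a notebook in place (illustrative file name).
uvx nbstripout notebook.ipynb
```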


@@ -47,6 +47,10 @@ wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
For an example of how to stream PyTorch tensors directly into the model,
[see our Colab notebook](https://colab.research.google.com/drive/1mc0Q-FoHxU2pEvId8rTdS4q1r1zorJhS?usp=sharing).
If you just want to run the model on a file, you can use `moshi.run_inference`.
This requires the [moshi package](https://pypi.org/project/moshi/)
with version 0.2.6 or later, which can be installed via pip.
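As a quick sketch (exact flags may differ between `moshi` versions; the model repo and audio file below are just examples), a file can be transcribed directly with:

```bash
# Install a recent moshi release and transcribe one file (illustrative model repo and file).
pip install "moshi>=0.2.6"
python -m moshi.run_inference --hf-repo kyutai/stt-1b-en_fr bria.mp3
```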
@@ -107,7 +111,7 @@ moshi-server worker --config configs/config-stt-en_fr-hf.toml
Once the server has started, you can run a streaming inference with the following
script.
```bash
uv run scripts/asr-streaming-query.py bria.mp3
uv run scripts/transcribe_from_file_via_rust_server.py bria.mp3
```
The script limits the decoding speed to simulate real-time processing of the audio.
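The simulated speed is the real-time factor handed to `stream_audio` (the `args.rtf` value visible in the script diff below); assuming the script exposes it as a `--rtf` flag, decoding can be sped up, for example:

```bash
# Hypothetical invocation: decode at 4x real time, assuming the script exposes an --rtf flag.
uv run scripts/transcribe_from_file_via_rust_server.py bria.mp3 --rtf 4
```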
@@ -166,3 +170,14 @@ Note that parts of this code are based on [AudioCraft](https://github.com/facebookresearch/audiocraft), released under
the MIT license.
The weights for the speech-to-text models are released under the CC-BY 4.0 license.
## Developing
Install the [pre-commit hooks](https://pre-commit.com/) by running:
```bash
pip install pre-commit
pre-commit install
```
If you're using `uv`, you can replace the two commands with `uvx pre-commit install`.
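By default the hooks only check the files staged for a commit; to run every hook across the whole repository once, you can also do:

```bash
pre-commit run --all-files
```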


@@ -41,18 +41,17 @@ uv run scripts/streaming_stt.py \
# Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34
# Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15
import dataclasses
import julius
import jiwer
from datasets import load_dataset, Dataset
from whisper.normalizers import EnglishTextNormalizer
import argparse
import torch
import moshi.models
import tqdm
import dataclasses
import time
import jiwer
import julius
import moshi.models
import torch
import tqdm
from datasets import Dataset, load_dataset
from whisper.normalizers import EnglishTextNormalizer
_NORMALIZER = EnglishTextNormalizer()
@@ -120,9 +119,9 @@ class AsrMetrics:
self.num_sequences += 1
def compute(self) -> dict:
assert (
self.num_sequences > 0
), "Unable to compute with total number of comparisons <= 0" # type: ignore
assert self.num_sequences > 0, (
"Unable to compute with total number of comparisons <= 0"
) # type: ignore
return {
"cer": (self.cer_sum / self.num_sequences),
"wer": (self.wer_sum / self.num_sequences),


@@ -19,15 +19,15 @@ uv run scripts/streaming_stt_timestamps.py \
```
"""
import itertools
import dataclasses
import julius
import sphn
import argparse
import dataclasses
import itertools
import math
import torch
import julius
import moshi.models
import sphn
import torch
import tqdm


@@ -10,17 +10,16 @@
import argparse
import asyncio
import json
import msgpack
import sphn
import struct
import time
import numpy as np
import msgpack
import sphn
import websockets
# Desired audio properties
TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
HEADERS = {"kyutai-api-key": "open_token"}
all_text = []
transcript = []
finished = False
@@ -44,11 +43,13 @@ async def receive_messages(websocket):
print("received:", data)
if data["type"] == "Word":
all_text.append(data["text"])
transcript.append({
"speaker": "SPEAKER_00",
"text": data["text"],
"timestamp": [data["start_time"], data["start_time"]],
})
transcript.append(
{
"speaker": "SPEAKER_00",
"text": data["text"],
"timestamp": [data["start_time"], data["start_time"]],
}
)
if data["type"] == "EndWord":
if len(transcript) > 0:
transcript[-1]["timestamp"][1] = data["stop_time"]
@@ -64,15 +65,19 @@ async def send_messages(websocket, rtf: float):
global finished
audio_data = load_and_process_audio(args.in_file)
try:
# Start with a second of silence
chunk = { "type": "Audio", "pcm": [0.0] * 24000 }
# Start with a second of silence.
# This is needed for the 2.6B model for technical reasons.
chunk = {"type": "Audio", "pcm": [0.0] * 24000}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
chunk_size = 1920 # Send data in chunks
start_time = time.time()
for i in range(0, len(audio_data), chunk_size):
chunk = { "type": "Audio", "pcm": [float(x) for x in audio_data[i : i + chunk_size]] }
chunk = {
"type": "Audio",
"pcm": [float(x) for x in audio_data[i : i + chunk_size]],
}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
expected_send_time = start_time + (i + 1) / 24000 / rtf
@@ -81,13 +86,15 @@ async def send_messages(websocket, rtf: float):
await asyncio.sleep(expected_send_time - current_time)
else:
await asyncio.sleep(0.001)
chunk = { "type": "Audio", "pcm": [0.0] * 1920 * 5 }
chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
msg = msgpack.packb({"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True)
msg = msgpack.packb(
{"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True
)
await websocket.send(msg)
for _ in range(35):
chunk = { "type": "Audio", "pcm": [0.0] * 1920 }
chunk = {"type": "Audio", "pcm": [0.0] * 1920}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
while True:
@@ -100,11 +107,10 @@ async def send_messages(websocket, rtf: float):
print("Connection closed while sending messages.")
async def stream_audio(url: str, rtf: float, api_key: str):
async def stream_audio(url: str, rtf: float):
"""Stream audio data to a WebSocket server."""
headers = {"kyutai-api-key": api_key}
async with websockets.connect(url, additional_headers=headers) as websocket:
async with websockets.connect(url, additional_headers=HEADERS) as websocket:
send_task = asyncio.create_task(send_messages(websocket, rtf))
receive_task = asyncio.create_task(receive_messages(websocket))
await asyncio.gather(send_task, receive_task)
@@ -115,7 +121,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("in_file")
parser.add_argument("--transcript")
parser.add_argument("--api-key", default="open_token")
parser.add_argument(
"--url",
help="The url of the server to which to send the audio",
@ -125,7 +130,7 @@ if __name__ == "__main__":
args = parser.parse_args()
url = f"{args.url}/api/asr-streaming"
asyncio.run(stream_audio(url, args.rtf, args.api_key))
asyncio.run(stream_audio(url, args.rtf))
print(" ".join(all_text))
if args.transcript is not None:
with open(args.transcript, "w") as fobj:


@@ -11,19 +11,16 @@
# ///
import argparse
from dataclasses import dataclass
import json
import numpy as np
import queue
import sounddevice as sd
from huggingface_hub import hf_hub_download
import mlx.core as mx
import mlx.nn as nn
from moshi_mlx import models, utils
import rustymimi
import sentencepiece
import sounddevice as sd
from huggingface_hub import hf_hub_download
from moshi_mlx import models, utils
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -69,6 +66,7 @@ if __name__ == "__main__":
)
block_queue = queue.Queue()
def audio_callback(indata, _frames, _time, _status):
block_queue.put(indata.copy())
@@ -84,7 +82,9 @@ if __name__ == "__main__":
block = block_queue.get()
block = block[None, :, 0]
other_audio_tokens = audio_tokenizer.encode_step(block[None, 0:1])
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[:, :, :other_codebooks]
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
:, :, :other_codebooks
]
text_token = gen.step(other_audio_tokens[0])
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
@@ -93,4 +93,3 @@ if __name__ == "__main__":
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("", " ")
print(_text, end="", flush=True)


@@ -9,9 +9,9 @@
# ///
import argparse
import asyncio
import msgpack
import signal
import msgpack
import numpy as np
import sounddevice as sd
import websockets
@@ -21,6 +21,7 @@ TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
audio_queue = asyncio.Queue()
async def receive_messages(websocket):
"""Receive and process messages from the WebSocket server."""
try:
@@ -47,22 +48,26 @@ async def send_messages(websocket):
except websockets.ConnectionClosed:
print("Connection closed while sending messages.")
async def stream_audio(url: str, api_key: str):
"""Stream audio data to a WebSocket server."""
print("Starting microphone recording...")
print("Press Ctrl+C to stop recording")
loop = asyncio.get_event_loop()
def audio_callback(indata, frames, time, status):
loop.call_soon_threadsafe(audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy())
# Start audio stream
def audio_callback(indata, frames, time, status):
loop.call_soon_threadsafe(
audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy()
)
# Start audio stream
with sd.InputStream(
samplerate=TARGET_SAMPLE_RATE,
channels=TARGET_CHANNELS,
dtype='float32',
dtype="float32",
callback=audio_callback,
blocksize=1920 # 80ms blocks
blocksize=1920, # 80ms blocks
):
headers = {"kyutai-api-key": api_key}
async with websockets.connect(url, additional_headers=headers) as websocket:
@@ -79,8 +84,12 @@ if __name__ == "__main__":
default="ws://127.0.0.1:8080",
)
parser.add_argument("--api-key", default="open_token")
parser.add_argument("--list-devices", action="store_true", help="List available audio devices")
parser.add_argument("--device", type=int, help="Input device ID (use --list-devices to see options)")
parser.add_argument(
"--list-devices", action="store_true", help="List available audio devices"
)
parser.add_argument(
"--device", type=int, help="Input device ID (use --list-devices to see options)"
)
args = parser.parse_args()


@@ -0,0 +1,240 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gJEMjPgeI-rw",
"outputId": "7491c067-b1be-4505-b3f5-19ba4c00a593"
},
"outputs": [],
"source": [
"!pip install moshi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "CA4K5iDFJcqJ",
"outputId": "b609843a-a193-4729-b099-5f8780532333"
},
"outputs": [],
"source": [
"!wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VA3Haix3IZ8Q"
},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"import time\n",
"import sentencepiece\n",
"import sphn\n",
"import textwrap\n",
"import torch\n",
"\n",
"from moshi.models import loaders, MimiModel, LMModel, LMGen"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9AK5zBMTI9bw"
},
"outputs": [],
"source": [
"@dataclass\n",
"class InferenceState:\n",
" mimi: MimiModel\n",
" text_tokenizer: sentencepiece.SentencePieceProcessor\n",
" lm_gen: LMGen\n",
"\n",
" def __init__(\n",
" self,\n",
" mimi: MimiModel,\n",
" text_tokenizer: sentencepiece.SentencePieceProcessor,\n",
" lm: LMModel,\n",
" batch_size: int,\n",
" device: str | torch.device,\n",
" ):\n",
" self.mimi = mimi\n",
" self.text_tokenizer = text_tokenizer\n",
" self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)\n",
" self.device = device\n",
" self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)\n",
" self.batch_size = batch_size\n",
" self.mimi.streaming_forever(batch_size)\n",
" self.lm_gen.streaming_forever(batch_size)\n",
"\n",
" def run(self, in_pcms: torch.Tensor):\n",
" device = self.lm_gen.lm_model.device\n",
" ntokens = 0\n",
" first_frame = True\n",
" chunks = [\n",
" c\n",
" for c in in_pcms.split(self.frame_size, dim=2)\n",
" if c.shape[-1] == self.frame_size\n",
" ]\n",
" start_time = time.time()\n",
" all_text = []\n",
" for chunk in chunks:\n",
" codes = self.mimi.encode(chunk)\n",
" if first_frame:\n",
" # Ensure that the first slice of codes is properly seen by the transformer\n",
" # as otherwise the first slice is replaced by the initial tokens.\n",
" tokens = self.lm_gen.step(codes)\n",
" first_frame = False\n",
" tokens = self.lm_gen.step(codes)\n",
" if tokens is None:\n",
" continue\n",
" assert tokens.shape[1] == 1\n",
" one_text = tokens[0, 0].cpu()\n",
" if one_text.item() not in [0, 3]:\n",
" text = self.text_tokenizer.id_to_piece(one_text.item())\n",
" text = text.replace(\"▁\", \" \")\n",
" all_text.append(text)\n",
" ntokens += 1\n",
" dt = time.time() - start_time\n",
" print(\n",
" f\"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step\"\n",
" )\n",
" return \"\".join(all_text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 353,
"referenced_widgets": [
"0a5f6f887e2b4cd1990a0e9ec0153ed9",
"f7893826fcba4bdc87539589d669249b",
"8805afb12c484781be85082ff02dad13",
"97679c0d9ab44bed9a3456f2fcb541fd",
"d73c0321bed54a52b5e1da0a7788e32a",
"d67be13a920d4fc89e5570b5b29fc1d2",
"6b377c2d7bf945fb89e46c39d246a332",
"b82ff365c78e41ad8094b46daf79449d",
"477aa7fa82dc42d5bce6f1743c45d626",
"cbd288510c474430beb66f346f382c45",
"aafc347cdf28428ea6a7abe5b46b726f",
"fca09acd5d0d45468c8b04bfb2de7646",
"79e35214b51b4a9e9b3f7144b0b34f7b",
"89e9a37f69904bd48b954d627bff6687",
"57028789c78248a7b0ad4f031c9545c9",
"1150fcb427994c2984d4d0f4e4745fe5",
"e24b1fc52f294f849019c9b3befb613f",
"8724878682cf4c3ca992667c45009398",
"36a22c977d5242008871310133b7d2af",
"5b3683cad5cb4877b43fadd003edf97f",
"703f98272e4d469d8f27f5a465715dd8",
"9dbe02ef5fac41cfaee3d02946e65c88",
"37faa87ad03a4271992c21ce6a629e18",
"570c547e48cd421b814b2c5e028e4c0b",
"b173768580fc4c0a8e3abf272e4c363a",
"e57d1620f0a9427b85d8b4885ef4e8e3",
"5dd4474df70743498b616608182714dd",
"cc907676a65f4ad1bf68a77b4a00e89b",
"a34abc3b118e4305951a466919c28ff6",
"a77ccfcdb90146c7a63b4b2d232bc494",
"f7313e6e3a27475993cab3961d6ae363",
"39b47fad9c554839868fe9e4bbf7def2",
"14e9511ea0bd44c49f0cf3abf1a6d40e",
"a4ea8e0c4cac4d5e88b7e3f527e4fe90",
"571afc0f4b2840c9830d6b5a307ed1f9",
"6ec593cab5b64f0ea638bb175b9daa5c",
"77a52aed00ae408bb24524880e19ec8a",
"0b2de4b29b4b44fe9d96361a40c793d0",
"3c5b5fb1a5ac468a89c1058bd90cfb58",
"e53e0a2a240e43cfa562c89b3d703dea",
"35966343cf9249ef8bc028a0d5c5f97d",
"e36a37e0d41c47ccb8bc6d56c19fb17c",
"279ccf7de43847a1a6579c9182a46cc8",
"41b5d6ab0b7d43c790a55f125c0e7494"
]
},
"id": "UsQJdAgkLp9n",
"outputId": "9b7131c3-69c5-4323-8312-2ce7621d8869"
},
"outputs": [],
"source": [
"device = \"cuda\"\n",
"# Use the en+fr low latency model, an alternative is kyutai/stt-2.6b-en\n",
"checkpoint_info = loaders.CheckpointInfo.from_hf_repo(\"kyutai/stt-1b-en_fr\")\n",
"mimi = checkpoint_info.get_mimi(device=device)\n",
"text_tokenizer = checkpoint_info.get_text_tokenizer()\n",
"lm = checkpoint_info.get_moshi(device=device)\n",
"in_pcms, _ = sphn.read(\"sample_fr_hibiki_crepes.mp3\", sample_rate=mimi.sample_rate)\n",
"in_pcms = torch.from_numpy(in_pcms).to(device=device)\n",
"\n",
"stt_config = checkpoint_info.stt_config\n",
"pad_left = int(stt_config.get(\"audio_silence_prefix_seconds\", 0.0) * 24000)\n",
"pad_right = int((stt_config.get(\"audio_delay_seconds\", 0.0) + 1.0) * 24000)\n",
"in_pcms = torch.nn.functional.pad(in_pcms, (pad_left, pad_right), mode=\"constant\")\n",
"in_pcms = in_pcms[None, 0:1].expand(1, -1, -1)\n",
"\n",
"state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)\n",
"text = state.run(in_pcms)\n",
"print(textwrap.fill(text, width=100))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 75
},
"id": "CIAXs9oaPrtj",
"outputId": "94cc208c-2454-4dd4-a64e-d79025144af5"
},
"outputs": [],
"source": [
"from IPython.display import Audio\n",
"\n",
"Audio(\"sample_fr_hibiki_crepes.mp3\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qkUZ6CBKOdTa"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "L4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}