From 7b818c263697dfe00c435e1c45c0631b1f3a134c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=A1clav=20Volhejn?=
Date: Wed, 25 Jun 2025 10:50:14 +0200
Subject: [PATCH] Rename examples and add pre-commit

---
 .pre-commit-config.yaml                       |  22 ++
 README.md                                     |  17 +-
 ...treaming_stt.py => evaluate_on_dataset.py} |  23 +-
 ...py => transcribe_from_file_via_pytorch.py} |  10 +-
 ...> transcribe_from_file_via_rust_server.py} |  45 ++--
 ...-mic.py => transcribe_from_mic_via_mlx.py} |  15 +-
 ...=> transcribe_from_mic_via_rust_server.py} |  33 ++-
 transcribe_via_pytorch.ipynb                  | 240 ++++++++++++++++++
 8 files changed, 347 insertions(+), 58 deletions(-)
 create mode 100644 .pre-commit-config.yaml
 rename scripts/{streaming_stt.py => evaluate_on_dataset.py} (98%)
 rename scripts/{streaming_stt_timestamps.py => transcribe_from_file_via_pytorch.py} (100%)
 rename scripts/{asr-streaming-query.py => transcribe_from_file_via_rust_server.py} (78%)
 rename scripts/{mlx-mic.py => transcribe_from_mic_via_mlx.py} (97%)
 rename scripts/{mic-query.py => transcribe_from_mic_via_rust_server.py} (86%)
 create mode 100644 transcribe_via_pytorch.ipynb

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..e2f0230
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,22 @@
+repos:
+  # Get rid of Jupyter Notebook output because we don't want to keep it in Git
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.8.1
+    hooks:
+      - id: nbstripout
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v5.0.0
+    hooks:
+      - id: check-added-large-files
+        args: ["--maxkb=2048"]
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.11.7
+    hooks:
+      # Run the linter.
+      - id: ruff
+        types_or: [python, pyi] # Don't run on `jupyter` files
+        args: [--fix]
+      # Run the formatter.
+      - id: ruff-format
+        types_or: [python, pyi] # Don't run on `jupyter` files
diff --git a/README.md b/README.md
index 99ed1bd..ffd0dea 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,10 @@ wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hib
 
 Open In Colab
 
+For an example of how to use the model in a way that lets you stream in PyTorch tensors directly,
+[see our Colab notebook](https://colab.research.google.com/drive/1mc0Q-FoHxU2pEvId8rTdS4q1r1zorJhS?usp=sharing).
+
+If you just want to run the model on a file, you can use `moshi.run_inference`.
 This requires the [moshi package](https://pypi.org/project/moshi/)
 with version 0.2.6 or later, which can be installed via pip.
 
@@ -107,7 +111,7 @@ moshi-server worker --config configs/config-stt-en_fr-hf.toml
 Once the server has started you can run a streaming inference with the following script.
 
 ```bash
-uv run scripts/asr-streaming-query.py bria.mp3
+uv run scripts/transcribe_from_file_via_rust_server.py bria.mp3
 ```
 
 The script limits the decoding speed to simulate real-time processing of the audio.
@@ -166,3 +170,14 @@ Note that parts of this code are based on [AudioCraft](https://github.com/faceboo
 the MIT license.
 
 The weights for the speech-to-text models are released under the CC-BY 4.0 license.
+
+## Developing
+
+Install the [pre-commit hooks](https://pre-commit.com/) by running:
+
+```bash
+pip install pre-commit
+pre-commit install
+```
+
+If you're using `uv`, you can replace the two commands with `uvx pre-commit install`.
\ No newline at end of file
diff --git a/scripts/streaming_stt.py b/scripts/evaluate_on_dataset.py
similarity index 98%
rename from scripts/streaming_stt.py
rename to scripts/evaluate_on_dataset.py
index bffe294..3bef8aa 100644
--- a/scripts/streaming_stt.py
+++ b/scripts/evaluate_on_dataset.py
@@ -41,18 +41,17 @@ uv run scripts/streaming_stt.py \
 # Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34
 # Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15
 
-import dataclasses
-import julius
-import jiwer
-from datasets import load_dataset, Dataset
-from whisper.normalizers import EnglishTextNormalizer
 import argparse
-
-import torch
-import moshi.models
-import tqdm
+import dataclasses
 import time
 
+import jiwer
+import julius
+import moshi.models
+import torch
+import tqdm
+from datasets import Dataset, load_dataset
+from whisper.normalizers import EnglishTextNormalizer
 
 _NORMALIZER = EnglishTextNormalizer()
@@ -120,9 +119,9 @@ class AsrMetrics:
         self.num_sequences += 1
 
     def compute(self) -> dict:
-        assert (
-            self.num_sequences > 0
-        ), "Unable to compute with total number of comparisons <= 0"  # type: ignore
+        assert self.num_sequences > 0, (
+            "Unable to compute with total number of comparisons <= 0"
+        )  # type: ignore
         return {
             "cer": (self.cer_sum / self.num_sequences),
             "wer": (self.wer_sum / self.num_sequences),
diff --git a/scripts/streaming_stt_timestamps.py b/scripts/transcribe_from_file_via_pytorch.py
similarity index 100%
rename from scripts/streaming_stt_timestamps.py
rename to scripts/transcribe_from_file_via_pytorch.py
index 7e0970b..e941da8 100644
--- a/scripts/streaming_stt_timestamps.py
+++ b/scripts/transcribe_from_file_via_pytorch.py
@@ -19,15 +19,15 @@ uv run scripts/streaming_stt_timestamps.py \
 ```
 """
 
-import itertools
-import dataclasses
-import julius
-import sphn
 import argparse
+import dataclasses
+import itertools
 import math
-import torch
 
+import julius
 import moshi.models
+import sphn
+import torch
 import tqdm
diff --git a/scripts/asr-streaming-query.py b/scripts/transcribe_from_file_via_rust_server.py
similarity index 78%
rename from scripts/asr-streaming-query.py
rename to scripts/transcribe_from_file_via_rust_server.py
index 0527556..3e90b79 100644
--- a/scripts/asr-streaming-query.py
+++ b/scripts/transcribe_from_file_via_rust_server.py
@@ -10,17 +10,16 @@ import argparse
 import asyncio
 import json
-import msgpack
-import sphn
-import struct
 import time
-import numpy as np
 
+import msgpack
+import sphn
 import websockets
 
 # Desired audio properties
 TARGET_SAMPLE_RATE = 24000
 TARGET_CHANNELS = 1  # Mono
+HEADERS = {"kyutai-api-key": "open_token"}
 
 all_text = []
 transcript = []
 finished = False
@@ -44,11 +43,13 @@ async def receive_messages(websocket):
             print("received:", data)
             if data["type"] == "Word":
                 all_text.append(data["text"])
-                transcript.append({
-                    "speaker": "SPEAKER_00",
-                    "text": data["text"],
-                    "timestamp": [data["start_time"], data["start_time"]],
-                })
+                transcript.append(
+                    {
+                        "speaker": "SPEAKER_00",
+                        "text": data["text"],
+                        "timestamp": [data["start_time"], data["start_time"]],
+                    }
+                )
             if data["type"] == "EndWord":
                 if len(transcript) > 0:
                     transcript[-1]["timestamp"][1] = data["stop_time"]
@@ -64,15 +65,19 @@ async def send_messages(websocket, rtf: float):
     global finished
     audio_data = load_and_process_audio(args.in_file)
     try:
-        # Start with a second of silence
-        chunk = { "type": "Audio", "pcm": [0.0] * 24000 }
+        # Start with a second of silence.
+        # This is needed for the 2.6B model for technical reasons.
+        chunk = {"type": "Audio", "pcm": [0.0] * 24000}
         msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
         await websocket.send(msg)
         chunk_size = 1920  # Send data in chunks
         start_time = time.time()
         for i in range(0, len(audio_data), chunk_size):
-            chunk = { "type": "Audio", "pcm": [float(x) for x in audio_data[i : i + chunk_size]] }
+            chunk = {
+                "type": "Audio",
+                "pcm": [float(x) for x in audio_data[i : i + chunk_size]],
+            }
             msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
             await websocket.send(msg)
             expected_send_time = start_time + (i + 1) / 24000 / rtf
@@ -81,13 +86,15 @@
                 await asyncio.sleep(expected_send_time - current_time)
             else:
                 await asyncio.sleep(0.001)
-        chunk = { "type": "Audio", "pcm": [0.0] * 1920 * 5 }
+        chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
         msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
         await websocket.send(msg)
-        msg = msgpack.packb({"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True)
+        msg = msgpack.packb(
+            {"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True
+        )
         await websocket.send(msg)
         for _ in range(35):
-            chunk = { "type": "Audio", "pcm": [0.0] * 1920 }
+            chunk = {"type": "Audio", "pcm": [0.0] * 1920}
             msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
             await websocket.send(msg)
         while True:
@@ -100,11 +107,10 @@
         print("Connection closed while sending messages.")
 
 
-async def stream_audio(url: str, rtf: float, api_key: str):
+async def stream_audio(url: str, rtf: float):
     """Stream audio data to a WebSocket server."""
-    headers = {"kyutai-api-key": api_key}
-
-    async with websockets.connect(url, additional_headers=headers) as websocket:
+    async with websockets.connect(url, additional_headers=HEADERS) as websocket:
         send_task = asyncio.create_task(send_messages(websocket, rtf))
         receive_task = asyncio.create_task(receive_messages(websocket))
         await asyncio.gather(send_task, receive_task)
 
@@ -115,7 +121,6 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("in_file")
     parser.add_argument("--transcript")
-    parser.add_argument("--api-key", default="open_token")
     parser.add_argument(
         "--url",
         help="The url of the server to which to send the audio",
@@ -125,7 +130,7 @@
     args = parser.parse_args()
 
     url = f"{args.url}/api/asr-streaming"
-    asyncio.run(stream_audio(url, args.rtf, args.api_key))
+    asyncio.run(stream_audio(url, args.rtf))
     print(" ".join(all_text))
     if args.transcript is not None:
         with open(args.transcript, "w") as fobj:
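For readers skimming the patch, the wire protocol that `transcribe_from_file_via_rust_server.py` speaks is worth seeing in isolation. The sketch below is an editor's condensation, not part of the commit: the `/api/asr-streaming` route, the `kyutai-api-key: open_token` header, and the msgpack `Audio`/`Word` message shapes are taken from the diff above, while everything else the script does (real-time pacing, trailing silence, the `Marker` used to detect the end of the stream) is omitted, so the last words may not be flushed.

```python
# Minimal sketch of the msgpack-over-WebSocket protocol used above.
# Assumes a moshi-server running locally with the default open_token key.
import asyncio

import msgpack
import sphn
import websockets


async def transcribe(path: str, url: str = "ws://127.0.0.1:8080") -> None:
    headers = {"kyutai-api-key": "open_token"}
    async with websockets.connect(
        f"{url}/api/asr-streaming", additional_headers=headers
    ) as ws:
        # The server expects 24 kHz mono PCM.
        pcm, _ = sphn.read(path, sample_rate=24000)
        # Stream the audio in 80 ms chunks (1920 samples at 24 kHz).
        for i in range(0, pcm.shape[-1], 1920):
            chunk = {"type": "Audio", "pcm": [float(x) for x in pcm[0, i : i + 1920]]}
            await ws.send(msgpack.packb(chunk, use_bin_type=True, use_single_float=True))
        # Print each word the server sends back until the connection closes.
        async for message in ws:
            data = msgpack.unpackb(message, raw=False)
            if data["type"] == "Word":
                print(data["text"], end=" ", flush=True)


asyncio.run(transcribe("bria.mp3"))
```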
diff --git a/scripts/mlx-mic.py b/scripts/transcribe_from_mic_via_mlx.py
similarity index 97%
rename from scripts/mlx-mic.py
rename to scripts/transcribe_from_mic_via_mlx.py
index 1816efe..e8792e2 100644
--- a/scripts/mlx-mic.py
+++ b/scripts/transcribe_from_mic_via_mlx.py
@@ -11,19 +11,16 @@
 # ///
 
 import argparse
-from dataclasses import dataclass
 import json
-import numpy as np
 import queue
-import sounddevice as sd
-from huggingface_hub import hf_hub_download
 
 import mlx.core as mx
 import mlx.nn as nn
-from moshi_mlx import models, utils
 import rustymimi
 import sentencepiece
-
+import sounddevice as sd
+from huggingface_hub import hf_hub_download
+from moshi_mlx import models, utils
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -69,6 +66,7 @@ if __name__ == "__main__":
     )
 
     block_queue = queue.Queue()
+
     def audio_callback(indata, _frames, _time, _status):
         block_queue.put(indata.copy())
 
@@ -84,7 +82,9 @@ if __name__ == "__main__":
             block = block_queue.get()
             block = block[None, :, 0]
             other_audio_tokens = audio_tokenizer.encode_step(block[None, 0:1])
-            other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[:, :, :other_codebooks]
+            other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
+                :, :, :other_codebooks
+            ]
             text_token = gen.step(other_audio_tokens[0])
             text_token = text_token[0].item()
             audio_tokens = gen.last_audio_tokens()
@@ -93,4 +93,3 @@ if __name__ == "__main__":
             _text = text_tokenizer.id_to_piece(text_token)  # type: ignore
             _text = _text.replace("▁", " ")
             print(_text, end="", flush=True)
-
diff --git a/scripts/mic-query.py b/scripts/transcribe_from_mic_via_rust_server.py
similarity index 86%
rename from scripts/mic-query.py
rename to scripts/transcribe_from_mic_via_rust_server.py
index ff41d59..85bad93 100644
--- a/scripts/mic-query.py
+++ b/scripts/transcribe_from_mic_via_rust_server.py
@@ -9,9 +9,9 @@
 # ///
 import argparse
 import asyncio
-import msgpack
 import signal
 
+import msgpack
 import numpy as np
 import sounddevice as sd
 import websockets
@@ -21,6 +21,7 @@ TARGET_SAMPLE_RATE = 24000
 TARGET_CHANNELS = 1  # Mono
 audio_queue = asyncio.Queue()
 
+
 async def receive_messages(websocket):
     """Receive and process messages from the WebSocket server."""
     try:
@@ -47,22 +48,26 @@ async def send_messages(websocket):
     except websockets.ConnectionClosed:
         print("Connection closed while sending messages.")
 
+
 async def stream_audio(url: str, api_key: str):
     """Stream audio data to a WebSocket server."""
     print("Starting microphone recording...")
     print("Press Ctrl+C to stop recording")
 
     loop = asyncio.get_event_loop()
 
-    def audio_callback(indata, frames, time, status):
-        loop.call_soon_threadsafe(audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy())
-    # Start audio stream
+    def audio_callback(indata, frames, time, status):
+        loop.call_soon_threadsafe(
+            audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy()
+        )
+
+    # Start audio stream
     with sd.InputStream(
         samplerate=TARGET_SAMPLE_RATE,
         channels=TARGET_CHANNELS,
-        dtype='float32',
+        dtype="float32",
         callback=audio_callback,
-        blocksize=1920  # 80ms blocks
+        blocksize=1920,  # 80ms blocks
     ):
         headers = {"kyutai-api-key": api_key}
         async with websockets.connect(url, additional_headers=headers) as websocket:
@@ -79,11 +84,15 @@ if __name__ == "__main__":
         default="ws://127.0.0.1:8080",
     )
     parser.add_argument("--api-key", default="open_token")
-    parser.add_argument("--list-devices", action="store_true", help="List available audio devices")
-    parser.add_argument("--device", type=int, help="Input device ID (use --list-devices to see options)")
-
+    parser.add_argument(
+        "--list-devices", action="store_true", help="List available audio devices"
+    )
+    parser.add_argument(
+        "--device", type=int, help="Input device ID (use --list-devices to see options)"
+    )
+
     args = parser.parse_args()
-
+
     def handle_sigint(signum, frame):
         print("Interrupted by user")
         exit(0)
@@ -94,9 +103,9 @@
         print("Available audio devices:")
         print(sd.query_devices())
         exit(0)
-
+
     if args.device is not None:
         sd.default.device[0] = args.device  # Set input device
-
+
     url = f"{args.url}/api/asr-streaming"
     asyncio.run(stream_audio(url, args.api_key))
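A note on the constant shared by the clients above, included as an editor's aside rather than patch content: blocks of 1920 samples at the 24 kHz target rate are exactly 80 ms of audio, which lines up with one Mimi token frame if we assume Mimi's 12.5 Hz frame rate (the notebook below computes the same quantity as `mimi.sample_rate / mimi.frame_rate`).

```python
TARGET_SAMPLE_RATE = 24000  # Hz, as defined in the scripts above
MIMI_FRAME_RATE = 12.5  # frames/s; assumed for illustration

block_size = int(TARGET_SAMPLE_RATE / MIMI_FRAME_RATE)
assert block_size == 1920  # matches blocksize=1920 ("80ms blocks") and chunk_size=1920
print(block_size / TARGET_SAMPLE_RATE)  # 0.08 -> one 80 ms block per frame
```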
"execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gJEMjPgeI-rw", + "outputId": "7491c067-b1be-4505-b3f5-19ba4c00a593" + }, + "outputs": [], + "source": [ + "!pip install moshi" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CA4K5iDFJcqJ", + "outputId": "b609843a-a193-4729-b099-5f8780532333" + }, + "outputs": [], + "source": [ + "!wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VA3Haix3IZ8Q" + }, + "outputs": [], + "source": [ + "from dataclasses import dataclass\n", + "import time\n", + "import sentencepiece\n", + "import sphn\n", + "import textwrap\n", + "import torch\n", + "\n", + "from moshi.models import loaders, MimiModel, LMModel, LMGen" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9AK5zBMTI9bw" + }, + "outputs": [], + "source": [ + "@dataclass\n", + "class InferenceState:\n", + " mimi: MimiModel\n", + " text_tokenizer: sentencepiece.SentencePieceProcessor\n", + " lm_gen: LMGen\n", + "\n", + " def __init__(\n", + " self,\n", + " mimi: MimiModel,\n", + " text_tokenizer: sentencepiece.SentencePieceProcessor,\n", + " lm: LMModel,\n", + " batch_size: int,\n", + " device: str | torch.device,\n", + " ):\n", + " self.mimi = mimi\n", + " self.text_tokenizer = text_tokenizer\n", + " self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)\n", + " self.device = device\n", + " self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)\n", + " self.batch_size = batch_size\n", + " self.mimi.streaming_forever(batch_size)\n", + " self.lm_gen.streaming_forever(batch_size)\n", + "\n", + " def run(self, in_pcms: torch.Tensor):\n", + " device = self.lm_gen.lm_model.device\n", + " ntokens = 0\n", + " first_frame = True\n", + " chunks = [\n", + " c\n", + " for c in in_pcms.split(self.frame_size, dim=2)\n", + " if c.shape[-1] == self.frame_size\n", + " ]\n", + " start_time = time.time()\n", + " all_text = []\n", + " for chunk in chunks:\n", + " codes = self.mimi.encode(chunk)\n", + " if first_frame:\n", + " # Ensure that the first slice of codes is properly seen by the transformer\n", + " # as otherwise the first slice is replaced by the initial tokens.\n", + " tokens = self.lm_gen.step(codes)\n", + " first_frame = False\n", + " tokens = self.lm_gen.step(codes)\n", + " if tokens is None:\n", + " continue\n", + " assert tokens.shape[1] == 1\n", + " one_text = tokens[0, 0].cpu()\n", + " if one_text.item() not in [0, 3]:\n", + " text = self.text_tokenizer.id_to_piece(one_text.item())\n", + " text = text.replace(\"▁\", \" \")\n", + " all_text.append(text)\n", + " ntokens += 1\n", + " dt = time.time() - start_time\n", + " print(\n", + " f\"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step\"\n", + " )\n", + " return \"\".join(all_text)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 353, + "referenced_widgets": [ + "0a5f6f887e2b4cd1990a0e9ec0153ed9", + "f7893826fcba4bdc87539589d669249b", + "8805afb12c484781be85082ff02dad13", + "97679c0d9ab44bed9a3456f2fcb541fd", + "d73c0321bed54a52b5e1da0a7788e32a", + "d67be13a920d4fc89e5570b5b29fc1d2", + "6b377c2d7bf945fb89e46c39d246a332", + "b82ff365c78e41ad8094b46daf79449d", + 
"477aa7fa82dc42d5bce6f1743c45d626", + "cbd288510c474430beb66f346f382c45", + "aafc347cdf28428ea6a7abe5b46b726f", + "fca09acd5d0d45468c8b04bfb2de7646", + "79e35214b51b4a9e9b3f7144b0b34f7b", + "89e9a37f69904bd48b954d627bff6687", + "57028789c78248a7b0ad4f031c9545c9", + "1150fcb427994c2984d4d0f4e4745fe5", + "e24b1fc52f294f849019c9b3befb613f", + "8724878682cf4c3ca992667c45009398", + "36a22c977d5242008871310133b7d2af", + "5b3683cad5cb4877b43fadd003edf97f", + "703f98272e4d469d8f27f5a465715dd8", + "9dbe02ef5fac41cfaee3d02946e65c88", + "37faa87ad03a4271992c21ce6a629e18", + "570c547e48cd421b814b2c5e028e4c0b", + "b173768580fc4c0a8e3abf272e4c363a", + "e57d1620f0a9427b85d8b4885ef4e8e3", + "5dd4474df70743498b616608182714dd", + "cc907676a65f4ad1bf68a77b4a00e89b", + "a34abc3b118e4305951a466919c28ff6", + "a77ccfcdb90146c7a63b4b2d232bc494", + "f7313e6e3a27475993cab3961d6ae363", + "39b47fad9c554839868fe9e4bbf7def2", + "14e9511ea0bd44c49f0cf3abf1a6d40e", + "a4ea8e0c4cac4d5e88b7e3f527e4fe90", + "571afc0f4b2840c9830d6b5a307ed1f9", + "6ec593cab5b64f0ea638bb175b9daa5c", + "77a52aed00ae408bb24524880e19ec8a", + "0b2de4b29b4b44fe9d96361a40c793d0", + "3c5b5fb1a5ac468a89c1058bd90cfb58", + "e53e0a2a240e43cfa562c89b3d703dea", + "35966343cf9249ef8bc028a0d5c5f97d", + "e36a37e0d41c47ccb8bc6d56c19fb17c", + "279ccf7de43847a1a6579c9182a46cc8", + "41b5d6ab0b7d43c790a55f125c0e7494" + ] + }, + "id": "UsQJdAgkLp9n", + "outputId": "9b7131c3-69c5-4323-8312-2ce7621d8869" + }, + "outputs": [], + "source": [ + "device = \"cuda\"\n", + "# Use the en+fr low latency model, an alternative is kyutai/stt-2.6b-en\n", + "checkpoint_info = loaders.CheckpointInfo.from_hf_repo(\"kyutai/stt-1b-en_fr\")\n", + "mimi = checkpoint_info.get_mimi(device=device)\n", + "text_tokenizer = checkpoint_info.get_text_tokenizer()\n", + "lm = checkpoint_info.get_moshi(device=device)\n", + "in_pcms, _ = sphn.read(\"sample_fr_hibiki_crepes.mp3\", sample_rate=mimi.sample_rate)\n", + "in_pcms = torch.from_numpy(in_pcms).to(device=device)\n", + "\n", + "stt_config = checkpoint_info.stt_config\n", + "pad_left = int(stt_config.get(\"audio_silence_prefix_seconds\", 0.0) * 24000)\n", + "pad_right = int((stt_config.get(\"audio_delay_seconds\", 0.0) + 1.0) * 24000)\n", + "in_pcms = torch.nn.functional.pad(in_pcms, (pad_left, pad_right), mode=\"constant\")\n", + "in_pcms = in_pcms[None, 0:1].expand(1, -1, -1)\n", + "\n", + "state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)\n", + "text = state.run(in_pcms)\n", + "print(textwrap.fill(text, width=100))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "id": "CIAXs9oaPrtj", + "outputId": "94cc208c-2454-4dd4-a64e-d79025144af5" + }, + "outputs": [], + "source": [ + "from IPython.display import Audio\n", + "\n", + "Audio(\"sample_fr_hibiki_crepes.mp3\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qkUZ6CBKOdTa" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "L4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}